| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109 |
- from __future__ import annotations
- import base64
- import datetime
- import functools
- import http.client
- import json
- import logging
- import os
- import re
- import socket
- import sys
- import threading
- from collections.abc import Iterable, Mapping, MutableMapping, Sequence
- from copy import deepcopy
- from pathlib import Path
- from typing import (
- IO,
- TYPE_CHECKING,
- Any,
- Callable,
- Literal,
- NamedTuple,
- TextIO,
- Union,
- overload,
- )
- import click
- from wandb_gql import Client, gql
- from wandb_gql.client import RetryError
- from wandb_graphql.language.ast import Document
- import wandb
- from wandb import env, util
- from wandb.analytics import get_sentry
- from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messages
- from wandb.errors import AuthenticationError, CommError, UsageError
- from wandb.integration.sagemaker import parse_sm_secrets
- from wandb.proto.wandb_internal_pb2 import ServerFeature
- from wandb.sdk import wandb_setup
- from wandb.sdk.internal import settings_static
- from wandb.sdk.internal._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
- from wandb.sdk.lib.gql_request import GraphQLSession
- from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
- from ..lib import retry, wbauth
- from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
- from . import context
- from .progress import Progress
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Default project name constant. Its consumers are not visible in this part of
# the file.
LAUNCH_DEFAULT_PROJECT = "model-registry"
if TYPE_CHECKING:
    # Everything in this branch is only seen by static type checkers; it is
    # never executed at runtime.
    from typing import Literal, TypedDict

    import requests

    from .progress import ProgressFn

    class CreateArtifactFileSpecInput(TypedDict, total=False):
        """Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql."""

        # total=False: every key below is optional.
        artifactID: str
        name: str
        md5: str
        mimetype: str | None
        artifactManifestID: str | None
        uploadPartsInput: list[dict[str, object]] | None

    class CreateArtifactFilesResponseFile(TypedDict):
        """Typed shape of one file entry carrying upload URLs and storage info.

        NOTE(review): presumably mirrors the createArtifactFiles GraphQL
        response -- confirm against schema.graphql.
        """

        id: str
        name: str
        displayName: str
        uploadUrl: str | None
        uploadHeaders: Sequence[str]
        uploadMultipartUrls: UploadPartsResponse
        storagePath: str
        artifact: CreateArtifactFilesResponseFileNode

    class CreateArtifactFilesResponseFileNode(TypedDict):
        """Typed shape of the `artifact` field of `CreateArtifactFilesResponseFile`."""

        id: str

    class UploadPartsResponse(TypedDict):
        """Typed shape of the `uploadMultipartUrls` field: per-part URLs plus an upload ID."""

        uploadUrlParts: list[UploadUrlParts]
        uploadID: str

    class UploadUrlParts(TypedDict):
        """Typed shape of one entry of `UploadPartsResponse.uploadUrlParts`."""

        partNumber: int
        uploadUrl: str

    class CompleteMultipartUploadArtifactInput(TypedDict):
        """Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql."""

        completeMultipartAction: str
        completedParts: dict[int, str]
        artifactID: str
        storagePath: str
        uploadID: str
        md5: str

    class CompleteMultipartUploadArtifactResponse(TypedDict):
        """Typed shape of a response carrying a single `digest` string.

        NOTE(review): presumably the result of completing a multipart upload
        -- confirm against schema.graphql.
        """

        digest: str

    class DefaultSettings(TypedDict, total=False):
        """Typed shape of the settings dict held in `Api.default_settings`."""

        # total=False: every key below is optional.
        section: str
        git_remote: str
        ignore_globs: list[str]
        base_url: str
        root_dir: str | None
        api_key: str | None
        entity: str | None
        organization: str | None
        project: str | None
        _extra_http_headers: Mapping[str, str] | None
        _proxies: Mapping[str, str] | None
# Type aliases used in annotations in this module.
# NOTE(review): indentation was lost in the copy under review; these aliases
# may originally have lived under `if TYPE_CHECKING:` -- confirm. Module level
# is safe at runtime because every name used here is imported at the top of
# the file.
_Response = MutableMapping
SweepState = Literal["RUNNING", "PAUSED", "CANCELED", "FINISHED"]
Number = Union[int, float]

# Logger that receives http.client debug output; see
# check_httpclient_logger_handler(). It is only raised to DEBUG when the
# WANDB_DEBUG environment variable is set to a non-empty value.
httpclient_logger = logging.getLogger("http.client")
if os.environ.get("WANDB_DEBUG"):
    httpclient_logger.setLevel(logging.DEBUG)
def check_httpclient_logger_handler() -> None:
    """Route `http.client` debug output into the logging framework.

    Does nothing unless the `WANDB_DEBUG` environment variable is set, and
    does nothing if the `http.client` logger already has a handler attached.
    Otherwise it replaces the `print` name inside the `http.client` module
    with a shim that logs at DEBUG level, turns on `HTTPConnection`
    debugging for the whole process, and attaches the first handler of the
    "wandb" logger (if it has one) to the `http.client` logger.
    """
    # Only enable http.client logging if WANDB_DEBUG is set
    if not os.environ.get("WANDB_DEBUG"):
        return
    if httpclient_logger.handlers:
        return

    # Enable HTTPConnection debug logging to the logging framework
    def httpclient_log(*args: Any) -> None:
        # print() accepts arbitrary objects, so stringify each argument before
        # joining; joining the raw args raises TypeError on any non-str value.
        httpclient_logger.log(logging.DEBUG, " ".join(map(str, args)))

    # mask the print() built-in in the http.client module to use logging instead
    http.client.print = httpclient_log  # type: ignore[attr-defined]
    # enable debugging
    http.client.HTTPConnection.debuglevel = 1

    root_logger = logging.getLogger("wandb")
    if root_logger.handlers:
        httpclient_logger.addHandler(root_logger.handlers[0])
- class _ThreadLocalData(threading.local):
- context: context.Context | None
- def __init__(self) -> None:
- self.context = None
- class _OrgNames(NamedTuple):
- entity_name: str
- display_name: str
- def _match_org_with_fetched_org_entities(
- organization: str, orgs: Sequence[_OrgNames]
- ) -> str:
- """Match the organization provided in the path with the org entity or org name of the input entity.
- Args:
- organization: The organization name to match
- orgs: list of tuples containing (org_entity_name, org_display_name)
- Returns:
- str: The matched org entity name
- Raises:
- ValueError: If no matching organization is found or if multiple orgs exist without a match
- """
- for org_names in orgs:
- if organization in org_names:
- return org_names.entity_name
- if len(orgs) == 1:
- raise ValueError(
- f"Expecting the organization name or entity name to match {orgs[0].display_name!r} "
- f"and cannot be linked/fetched with {organization!r}. "
- "Please update the target path with the correct organization name."
- )
- raise ValueError(
- "Personal entity belongs to multiple organizations "
- f"and cannot be linked/fetched with {organization!r}. "
- "Please update the target path with the correct organization name "
- "or use a team entity in the entity settings."
- )
- class Api:
- """W&B Internal Api wrapper.
- Note:
- Settings are automatically overridden by looking for
- a `wandb/settings` file in the current working directory or its parent
- directory. If none can be found, we look in the current user's home
- directory.
- Args:
- default_settings(dict, optional): If you aren't using a settings
- file, or you wish to override the section to use in the settings file
- Override the settings here.
- """
- HTTP_TIMEOUT = env.get_http_timeout(20)
- FILE_PUSHER_TIMEOUT = env.get_file_pusher_timeout()
- _global_context: context.Context
- _local_data: _ThreadLocalData
def __init__(
    self,
    default_settings: (
        wandb.Settings  #
        | settings_static.SettingsStatic
        | DefaultSettings
        | None
    ) = None,
    load_settings: bool = True,
    retry_timedelta: datetime.timedelta | None = None,
    environ: MutableMapping[str, str] = os.environ,
    retry_callback: Callable[[int, str], Any] | None = None,
    api_key: str | None = None,
) -> None:
    """Build the GraphQL client, upload session and retry wrappers.

    Args:
        default_settings: Optional overrides for the built-in defaults
            (section, git_remote, base_url, entity, project, ...).
        load_settings: When True, system settings are read through the
            `wandb_setup` singleton; when False, no stored settings are used.
        retry_timedelta: Passed to the retry wrappers created here. When
            omitted, `self.retry_timedelta` falls back to 7 days.
        environ: Mapping consulted for environment-style configuration.
        retry_callback: Callback handed to the GraphQL retry wrapper.
        api_key: API key that takes precedence over the one in
            `default_settings`.
    """
    import requests

    self._environ = environ
    self._global_context = context.Context()
    self._local_data = _ThreadLocalData()
    default_overrides: dict[str, Any] = (
        dict(default_settings) if default_settings else {}
    )
    self.default_settings: DefaultSettings = {
        "section": default_overrides.get("section", "default"),
        "git_remote": default_overrides.get("git_remote", "origin"),
        "ignore_globs": default_overrides.get("ignore_globs", []),
        "base_url": default_overrides.get("base_url", "https://api.wandb.ai"),
        "root_dir": default_overrides.get("root_dir"),
        "api_key": default_overrides.get("api_key"),
        "entity": default_overrides.get("entity"),
        "organization": default_overrides.get("organization"),
        "project": default_overrides.get("project"),
        "_extra_http_headers": default_overrides.get("_extra_http_headers"),
        "_proxies": default_overrides.get("_proxies"),
    }
    if load_settings:
        global_settings = wandb_setup.singleton().settings
        if root_dir := self.default_settings["root_dir"]:
            # Work on a copy so the singleton's settings are not modified.
            global_settings = global_settings.model_copy()
            global_settings.root_dir = root_dir
        self._settings = global_settings.read_system_settings().all()
    else:
        self._settings = {}

    # Mutable settings set by the _file_stream_api
    self.dynamic_settings = {
        "system_sample_seconds": 2,
        "system_samples": 15,
        "heartbeat_seconds": 30,
    }
    self.retry_timedelta = retry_timedelta or datetime.timedelta(days=7)
    self.retry_uploads = 10

    # todo: remove these hacky hacks after settings refactor is complete
    # keeping this code here to limit scope and so that it is easy to remove later
    #
    # Copy into a private dict: the mapping may be the very object the caller
    # passed in (or one owned by a settings object), and the "Authorization"
    # header added below must not leak into that shared mapping.
    self._extra_http_headers = dict(
        self.settings("_extra_http_headers")
        or json.loads(self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}"))
    )

    # Credential precedence: explicit/default API key, then an access token
    # exchanged from an identity token file, then the resolved `api_key`
    # property (empty string when none is available).
    auth = None
    api_key = api_key or self.default_settings.get("api_key")
    if api_key:
        auth = ("api", api_key)
    elif self.access_token is not None:
        self._extra_http_headers["Authorization"] = f"Bearer {self.access_token}"
    else:
        auth = ("api", self.api_key or "")

    proxies = self.settings("_proxies") or json.loads(
        self._environ.get("WANDB__PROXIES", "{}")
    )

    self.client = Client(
        transport=GraphQLSession(
            headers={
                "User-Agent": self.user_agent,
                "X-WANDB-USERNAME": env.get_username(env=self._environ),
                "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
                **self._extra_http_headers,
            },
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self.HTTP_TIMEOUT,
            auth=auth,
            url=f"{self.settings('base_url')}/graphql",
            proxies=proxies,
        )
    )

    self.retry_callback = retry_callback
    self._retry_gql = retry.Retry(
        self.execute,
        retry_timedelta=retry_timedelta,
        check_retry_fn=util.no_retry_auth,
        retryable_exceptions=(RetryError, requests.RequestException),
        retry_callback=retry_callback,
    )
    self._current_run_id: str | None = None
    self._file_stream_api = None
    self._upload_file_session = requests.Session()
    if self.FILE_PUSHER_TIMEOUT:
        # Bake the timeout into every PUT issued through the upload session.
        self._upload_file_session.put = functools.partial(  # type: ignore
            self._upload_file_session.put,
            timeout=self.FILE_PUSHER_TIMEOUT,
        )
    if proxies:
        self._upload_file_session.proxies.update(proxies)
    # This Retry class is initialized once for each Api instance, so this
    # defaults to retrying 1 million times per process or 7 days
    self.upload_file_retry = normalize_exceptions(
        retry.retriable(retry_timedelta=retry_timedelta)(self.upload_file)
    )
    self.upload_multipart_file_chunk_retry = normalize_exceptions(
        retry.retriable(retry_timedelta=retry_timedelta)(
            self.upload_multipart_file_chunk
        )
    )
    self._client_id_mapping: dict[str, str] = {}
    # Large file uploads to azure can optionally use their SDK
    self._azure_blob_module = util.get_module("azure.storage.blob")

    self._max_cli_version: str | None = None
    self._server_features_cache: dict[str, bool] | None = None
def gql(self, *args: Any, **kwargs: Any) -> Any:
    """Run a GraphQL request through the retrying executor.

    All arguments are forwarded to the retry wrapper unchanged, together
    with the cancel event of the currently active context.
    """
    cancel_event = self.context.cancel_event
    return self._retry_gql(*args, retry_cancel_event=cancel_event, **kwargs)
def set_local_context(self, api_context: context.Context | None) -> None:
    """Install `api_context` as the context override for the calling thread."""
    thread_state = self._local_data
    thread_state.context = api_context
def clear_local_context(self) -> None:
    """Remove the calling thread's context override, if any."""
    thread_state = self._local_data
    thread_state.context = None
@property
def context(self) -> context.Context:
    """The active context: the thread-local override if truthy, else the global one."""
    thread_context = self._local_data.context
    return thread_context or self._global_context
def reauth(self) -> None:
    """Ensure the current api key is set in the transport."""
    # Fall back to an empty password when no API key can be resolved.
    credentials = ("api", self.api_key or "")
    self.client.transport.session.auth = credentials
def relocate(self) -> None:
    """Ensure the current api points to the right server."""
    base_url = self.settings("base_url")
    self.client.transport.url = "{}/graphql".format(base_url)
def execute(self, *args: Any, **kwargs: Any) -> _Response:
    """Wrapper around execute that logs in cases of failure.

    Args:
        *args: Positional arguments forwarded to `self.client.execute`.
        **kwargs: Keyword arguments forwarded to `self.client.execute`.

    Returns:
        Whatever `self.client.execute` returns.

    Raises:
        requests.exceptions.HTTPError: Re-raised unchanged after the failure
            has been logged and any backend error messages printed.
    """
    import requests

    try:
        return self.client.execute(*args, **kwargs)  # type: ignore
    except requests.exceptions.HTTPError as err:
        logger.exception("Error executing GraphQL.")
        response = err.response
        # An HTTPError does not always carry a response. Guard explicitly
        # rather than with `assert`, which would replace the real error with
        # an AssertionError and is stripped under `python -O`.
        if response is not None:
            for error in parse_backend_error_messages(response):
                wandb.termerror(f"Error while calling W&B API: {error} ({response})")
        raise
def validate_api_key(self) -> bool:
    """Returns whether the API key stored on initialization is valid.

    Issues a minimal `viewer` query; the key counts as valid when the
    response is present and has a non-null `viewer`.
    """
    viewer_query = gql("query { viewer { id } }")
    res = self.gql(viewer_query)
    if res is None:
        return False
    return res["viewer"] is not None
def set_current_run_id(self, run_id: str) -> None:
    """Remember `run_id` as the run this client is currently working with."""
    self._current_run_id = run_id
@property
def current_run_id(self) -> str | None:
    """The run ID last stored via `set_current_run_id`, or None if unset."""
    return self._current_run_id
@property
def user_agent(self) -> str:
    """The User-Agent string for this client, including the wandb version."""
    version = wandb.__version__
    return f"W&B Internal Client {version}"
@property
def api_key(self) -> str | None:
    """Resolve the API key to use for the configured host.

    Sources are tried in order and the first truthy value wins: session
    credentials (only when they are an API key), the API-key environment
    variable, the netrc entry for the host, SageMaker secrets, and finally
    the `api_key` entry of the default settings.
    """
    from wandb.sdk.lib import wbauth

    session_auth = wbauth.session_credentials(host=self.api_url)
    if session_auth and isinstance(session_auth, wbauth.AuthApiKey):
        return session_auth.api_key

    # Each fallback is only consulted when the previous one yielded nothing.
    from_env = os.getenv(env.API_KEY)
    if from_env:
        return from_env
    from_netrc = wbauth.read_netrc_auth(host=self.api_url)
    if from_netrc:
        return from_netrc
    from_sagemaker = parse_sm_secrets().get(env.API_KEY)
    if from_sagemaker:
        return from_sagemaker
    return self.default_settings.get("api_key")
@property
def access_token(self) -> str | None:
    """Retrieves an access token for authentication.

    The path of an identity token file is read from the environment
    mapping. If it is not set, there is nothing to exchange and None is
    returned. Otherwise the identity token is exchanged for an access token
    via `wbauth.AuthIdentityTokenFile`, which is also given the path of the
    credentials file.

    Returns:
        str | None: The access token if available, otherwise None if
            no identity token is supplied.

    Raises:
        AuthenticationError: If the path to the identity token is not found.
    """
    identity_token_path = self._environ.get(env.IDENTITY_TOKEN_FILE)
    if not identity_token_path:
        return None

    token_file = Path(identity_token_path)
    if not token_file.exists():
        raise AuthenticationError(f"Identity token file not found: {token_file}")

    return wbauth.AuthIdentityTokenFile(
        host=self.settings("base_url"),
        path=str(token_file),
        credentials_file=wandb_setup.singleton().settings.credentials_file,
    ).fetch_access_token()
@property
def api_url(self) -> str:
    """The API base URL, taken from the `base_url` setting."""
    base_url = self.settings("base_url")
    return base_url  # type: ignore
@property
def app_url(self) -> str:
    """The app URL that `wandb.util.app_url` derives from the API URL."""
    api_url = self.api_url
    return wandb.util.app_url(api_url)
@property
def default_entity(self) -> str:
    """The `entity` field of the viewer returned by `self.viewer()`."""
    viewer = self.viewer()
    return viewer.get("entity")  # type: ignore
@overload
def settings(self, key: None = None) -> dict[str, Any]: ...


@overload
def settings(self, key: str) -> Any: ...


def settings(self, key: str | None = None) -> Any:
    """The settings overridden from the wandb/settings file.

    The defaults are overlaid with the loaded settings, and then `entity`,
    `organization`, `project` and `base_url` are each passed through the
    matching `env.get_*` resolver together with the environment mapping.

    Args:
        key (str, optional): If provided only this setting is returned

    Returns:
        A dict with the current settings, or just the value stored under
        `key` when a key is given, e.g.

            {
                "entity": "models",
                "base_url": "https://api.wandb.ai",
                "project": None,
                "organization": "my-org",
            }
    """
    merged: dict[str, Any] = dict(self.default_settings)
    merged.update(self._settings)

    # Settings that an `env` resolver gets the final say on, in order.
    env_resolvers = (
        ("entity", env.get_entity),
        ("organization", env.get_organization),
        ("project", env.get_project),
        ("base_url", env.get_base_url),
    )
    for name, resolve in env_resolvers:
        merged[name] = resolve(
            self._settings.get(name, merged.get(name)),
            env=self._environ,
        )

    if key is None:
        return merged
    return merged[key]
def clear_setting(self, key: str) -> None:
    """Drop `key` from the loaded settings; a missing key is ignored."""
    loaded = self._settings
    loaded.pop(key, None)
def set_setting(self, key: str, value: Any) -> None:
    """Store a setting and propagate the ones that have side effects.

    `entity` and `project` are also written to the environment mapping via
    `env`; changing `base_url` re-points the GraphQL transport.
    """
    self._settings[key] = value

    if key == "base_url":
        self.relocate()
        return
    if key == "entity":
        env.set_entity(value, env=self._environ)
    elif key == "project":
        env.set_project(value, env=self._environ)
def parse_slug(
    self, slug: str, project: str | None = None, run: str | None = None
) -> tuple[str, str]:
    """Parse a slug into a project and run.

    A slug containing "/" is split on it and its first two components are
    used directly. Otherwise the project comes from the argument or the
    settings, and the run from the argument, the slug, the current run ID or
    the environment, in that order.

    Args:
        slug (str): The slug to parse
        project (str, optional): The project to use, if not provided it will be
            inferred from the slug
        run (str, optional): The run to use, if not provided it will be inferred
            from the slug

    Returns:
        A `(project, run)` tuple.

    Raises:
        CommError: If no project is given and none is configured.
        AssertionError: If no run can be determined.
    """
    if slug and "/" in slug:
        pieces = slug.split("/")
        return pieces[0], pieces[1]

    project = project or self.settings().get("project")
    if project is None:
        raise CommError("No default project configured.")
    run = run or slug or self.current_run_id or env.get_run(env=self._environ)
    assert run, "run must be specified"
    return project, run
@normalize_exceptions
def fail_run_queue_item(
    self,
    run_queue_item_id: str,
    message: str,
    stage: str,
    file_paths: list[str] | None = None,
) -> bool:
    """Send the `failRunQueueItem` mutation for a run queue item.

    Args:
        run_queue_item_id: ID of the run queue item.
        message: Message sent with the mutation.
        stage: Stage sent with the mutation.
        file_paths: Optional file paths; only sent when not None.

    Returns:
        The `success` flag from the mutation response.
    """
    variable_values: dict[str, str | (list[str] | None)] = {
        "runQueueItemId": run_queue_item_id,
        "message": message,
        "stage": stage,
    }
    if file_paths is not None:
        variable_values["filePaths"] = file_paths

    mutation = gql(
        """
    mutation failRunQueueItem($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
        failRunQueueItem(
            input: {
                runQueueItemId: $runQueueItemId
                message: $message
                stage: $stage
                filePaths: $filePaths
            }
        ) {
            success
        }
    }
    """
    )
    response = self.gql(mutation, variable_values=variable_values)
    succeeded: bool = response["failRunQueueItem"]["success"]
    return succeeded
def _server_features(self) -> dict[str, bool]:
    """Query the server's feature flags and return them keyed by name.

    The result is also stored on `self._server_features_cache`. When the
    server rejects the query because `ServerInfo` has no `features` field,
    an empty mapping is stored and returned; any other error propagates.
    """
    # NOTE: Avoid caching via `@cached_property`, due to undocumented
    # locking behavior before Python 3.12.
    # See: https://github.com/python/cpython/issues/87634
    # NOTE(review): the cache attribute is assigned below but never read in
    # this method, so every call issues a request -- confirm this is intended.
    query = gql(SERVER_FEATURES_QUERY_GQL)
    features: dict[str, bool] = {}
    try:
        response = self.gql(query)
    except Exception as e:
        # Unfortunately we currently have to match on the text of the error message,
        # as the `gql` client raises `Exception` rather than a more specific error.
        if 'Cannot query field "features" on type "ServerInfo".' not in str(e):
            raise
    else:
        info = ServerFeaturesQuery.model_validate(response).server_info
        if info and (feats := info.features):
            features = {f.name: f.is_enabled for f in feats if f}

    self._server_features_cache = features
    return features
def _server_supports(self, feature: int | str) -> bool:
    """Return whether the current server supports the given feature.

    NOTE: This is deprecated. Outside of this file, please use
    `ServiceApi.feature_enabled()`. The `ServiceApi` is a sort of
    replacement to this "internal" `Api` class.

    The lookup goes through `self._server_features()`, which maps
    {feature_name (str) -> is_enabled (bool)}. Unknown features are
    reported as unsupported.

    Good to use for features that have a fallback mechanism for older servers.
    """
    # Feature names (str), not enum values (int), are used as lookup keys:
    # - the server identifies features by their name, rather than (client-side) enum value
    # - the defined list of client-side flags may be behind the server-side list of flags
    if isinstance(feature, int):
        # Given the protobuf enum value; convert it to its string name.
        key = ServerFeature.Name(feature)
    else:
        key = feature
    enabled = self._server_features().get(key)
    return enabled or False
@normalize_exceptions
def update_run_queue_item_warning(
    self,
    run_queue_item_id: str,
    message: str,
    stage: str,
    file_paths: list[str] | None = None,
) -> bool:
    """Send the `updateRunQueueItemWarning` mutation for a run queue item.

    Args:
        run_queue_item_id (str): ID of the run queue item.
        message (str): Message sent with the mutation.
        stage (str): Stage value sent with the mutation.
        file_paths (list[str], optional): Sent as `filePaths`; None is sent
            as-is.

    Returns:
        The `success` flag from the mutation payload.
    """
    payload = {
        "runQueueItemId": run_queue_item_id,
        "message": message,
        "stage": stage,
        "filePaths": file_paths,
    }
    mutation = gql(
        """
        mutation updateRunQueueItemWarning($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
            updateRunQueueItemWarning(
                input: {
                    runQueueItemId: $runQueueItemId
                    message: $message
                    stage: $stage
                    filePaths: $filePaths
                }
            ) {
                success
            }
        }
        """
    )
    response = self.gql(mutation, variable_values=payload)
    succeeded: bool = response["updateRunQueueItemWarning"]["success"]
    return succeeded
@normalize_exceptions
def viewer(self) -> dict[str, Any]:
    """Fetch the `viewer` node (id, entity, username, flags, team names).

    Returns:
        The `viewer` dict from the response, or `{}` when it is missing
        or empty.
    """
    viewer_query = gql(
        """
        query Viewer{
            viewer {
                id
                entity
                username
                flags
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        """
    )
    response = self.gql(viewer_query)
    viewer_node = response.get("viewer")
    return viewer_node or {}
@normalize_exceptions
def max_cli_version(self) -> str | None:
    """Return the server-reported maximum CLI version, if any.

    The value is read from `serverInfo.cliVersionInfo["max_cli_version"]`
    via `viewer_server_info()` and memoized on `self._max_cli_version`
    once it is not None.

    Returns:
        The version string, or None when the server does not report one.
    """
    if self._max_cli_version is not None:
        return self._max_cli_version

    _, server_info = self.viewer_server_info()
    # `dict.get(key, {})` only applies its default when the key is absent;
    # a key present with a null value would otherwise make the chained
    # `.get` raise AttributeError, so fall back with `or {}` instead.
    cli_version_info = server_info.get("cliVersionInfo") or {}
    self._max_cli_version = cli_version_info.get("max_cli_version")
    return self._max_cli_version
@normalize_exceptions
def viewer_server_info(self) -> tuple[dict[str, Any], dict[str, Any]]:
    """Fetch the `viewer` and `serverInfo` nodes in a single query.

    Returns:
        A `(viewer, server_info)` tuple; each element is `{}` when the
        corresponding node is missing or empty in the response.
    """
    combined_query = gql(
        """
        query Viewer{
            viewer {
                id
                entity
                username
                email
                flags
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
            serverInfo {
                cliVersionInfo
                latestLocalVersionInfo {
                    outOfDate
                    latestVersionString
                    versionOnThisInstanceString
                }
            }
        }
        """
    )
    response = self.gql(combined_query)
    viewer_node = response.get("viewer") or {}
    server_info_node = response.get("serverInfo") or {}
    return viewer_node, server_info_node
@normalize_exceptions
def list_projects(self, entity: str | None = None) -> list[dict[str, str]]:
    """List projects in W&B scoped by entity.

    Args:
        entity (str, optional): The entity to scope the listing to. Falls
            back to the `entity` setting.

    Returns:
        The flattened `models` edges; each node carries `id`, `name` and
        `description`. The query requests the first 10 projects.
    """
    projects_query = gql(
        """
        query EntityProjects($entity: String) {
            models(first: 10, entityName: $entity) {
                edges {
                    node {
                        id
                        name
                        description
                    }
                }
            }
        }
        """
    )
    scope = entity or self.settings("entity")
    response = self.gql(projects_query, variable_values={"entity": scope})
    projects: list[dict[str, str]] = self._flatten_edges(response["models"])
    return projects
@normalize_exceptions
def project(self, project: str, entity: str | None = None) -> _Response:
    """Retrieve details for a single project.

    Args:
        project (str): The project to get details for.
        entity (str, optional): The entity to scope this project to. It is
            sent as given (no fallback to settings).

    Returns:
        The `model` node from the response, selecting `id`, `name`, `repo`,
        `dockerImage` and `description`.
    """
    details_query = gql(
        """
        query ProjectDetails($entity: String, $project: String) {
            model(name: $project, entityName: $entity) {
                id
                name
                repo
                dockerImage
                description
            }
        }
        """
    )
    variables = {"entity": entity, "project": project}
    payload = self.gql(details_query, variable_values=variables)
    model_node: _Response = payload["model"]
    return model_node
@normalize_exceptions
def sweep(
    self,
    sweep: str,
    specs: str,
    project: str | None = None,
    entity: str | None = None,
) -> dict[str, Any]:
    """Retrieve a sweep together with its runs.

    Args:
        sweep (str): The sweep to get details for.
        specs (str): History specs, passed to `sampledHistory` for each run.
        project (str, optional): The project to scope this sweep to. Falls
            back to the `project` setting.
        entity (str, optional): The entity to scope this sweep to. Falls
            back to the `entity` setting.

    Returns:
        The `sweep` node from the response, with its `runs` connection
        replaced by the result of `_flatten_edges`.

    Raises:
        ValueError: If the project or the sweep is null in the response.
    """
    sweep_query = gql(
        """
        query SweepWithRuns($entity: String, $project: String, $sweep: String!, $specs: [JSONString!]!) {
            project(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    id
                    name
                    method
                    state
                    description
                    config
                    createdAt
                    heartbeatAt
                    updatedAt
                    earlyStopJobRunning
                    bestLoss
                    controller
                    scheduler
                    runs {
                        edges {
                            node {
                                name
                                state
                                config
                                exitcode
                                heartbeatAt
                                shouldStop
                                failed
                                stopped
                                running
                                summaryMetrics
                                sampledHistory(specs: $specs)
                            }
                        }
                    }
                }
            }
        }
        """
    )
    entity = entity or self.settings("entity")
    project = project or self.settings("project")
    variables = {
        "entity": entity,
        "project": project,
        "sweep": sweep,
        "specs": specs,
    }
    response = self.gql(sweep_query, variable_values=variables)

    project_node = response["project"]
    sweep_node = None if project_node is None else project_node["sweep"]
    if sweep_node is None:
        raise ValueError(f"Sweep {entity}/{project}/{sweep} not found")

    data: dict[str, Any] = sweep_node
    if data:
        data["runs"] = self._flatten_edges(data["runs"])
    return data
@normalize_exceptions
def list_runs(
    self, project: str, entity: str | None = None
) -> list[dict[str, str]]:
    """List runs in W&B scoped by project.

    Args:
        project (str): The project to scope the runs to. Falls back to the
            `project` setting when empty.
        entity (str, optional): The entity to scope this project to. Falls
            back to the `entity` setting.

    Returns:
        The flattened `buckets` edges; each node carries `id`, `name`,
        `displayName` and `description`. The query requests the first 10
        runs.
    """
    runs_query = gql(
        """
        query ProjectRuns($model: String!, $entity: String) {
            model(name: $model, entityName: $entity) {
                buckets(first: 10) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                        }
                    }
                }
            }
        }
        """
    )
    variables = {
        "entity": entity or self.settings("entity"),
        "model": project or self.settings("project"),
    }
    response = self.gql(runs_query, variable_values=variables)
    return self._flatten_edges(response["model"]["buckets"])
@normalize_exceptions
def run_config(
    self, project: str, run: str | None = None, entity: str | None = None
) -> tuple[str, dict[str, Any], str | None, dict[str, Any]]:
    """Get the relevant configs for a run.

    Args:
        project (str): The project to download, (can include bucket)
        run (str, optional): The run to download
        entity (str, optional): The entity to scope this project to.

    Returns:
        A `(commit, config, patch, metadata)` tuple: the run's commit, its
        config parsed from JSON, the text of the `DIFF_FNAME` file (None
        if absent) and the parsed JSON of the `METADATA_FNAME` file (`{}`
        if absent).

    Raises:
        CommError: If the project or the run is null in the response.
    """
    import requests

    check_httpclient_logger_handler()
    query = gql(
        """
        query RunConfigs(
            $name: String!,
            $entity: String,
            $run: String!,
            $pattern: String!,
            $includeConfig: Boolean!,
        ) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    config @include(if: $includeConfig)
                    commit @include(if: $includeConfig)
                    files(pattern: $pattern) {
                        pageInfo {
                            hasNextPage
                            endCursor
                        }
                        edges {
                            node {
                                name
                                directUrl
                            }
                        }
                    }
                }
            }
        }
        """
    )
    variable_values: dict[str, Any] = {
        "name": project,
        "run": run,
        "entity": entity,
        "includeConfig": True,
    }
    commit: str = ""
    config: dict[str, Any] = {}
    patch: str | None = None
    metadata: dict[str, Any] = {}

    # If we use the `names` parameter on the `files` node, then the server
    # will helpfully give us an 'open' file handle to the files that don't
    # exist. This is so that we can upload data to it. However, in this
    # case, we just want to download that file and not upload to it, so
    # let's instead query for the files that do exist using `pattern`
    # (with no wildcards).
    #
    # Unfortunately we're unable to construct a single pattern that matches
    # our 2 files, we would need something like regex for that.
    for filename in [DIFF_FNAME, METADATA_FNAME]:
        variable_values["pattern"] = filename
        response = self.gql(query, variable_values=variable_values)

        model = response["model"]
        run_obj = None if model is None else model["bucket"]
        # `run_obj` is indexed below, so a null bucket would otherwise
        # surface as an opaque TypeError; report it like a null model.
        if run_obj is None:
            raise CommError(f"Run {entity}/{project}/{run} not found")

        # we only need to fetch this config once
        if variable_values["includeConfig"]:
            commit = run_obj["commit"]
            config = json.loads(run_obj["config"] or "{}")
            variable_values["includeConfig"] = False

        if run_obj["files"] is not None:
            for file_edge in run_obj["files"]["edges"]:
                name = file_edge["node"]["name"]
                url = file_edge["node"]["directUrl"]
                res = requests.get(url)
                res.raise_for_status()
                if name == METADATA_FNAME:
                    metadata = res.json()
                elif name == DIFF_FNAME:
                    patch = res.text

    return commit, config, patch, metadata
@normalize_exceptions
def run_resume_status(
    self, entity: str, project_name: str, name: str
) -> dict[str, Any] | None:
    """Check if a run exists and get resume information.

    Args:
        entity (str): The entity to scope this project to.
        project_name (str): The project to look the run up in.
        name (str): The run to look up.

    Returns:
        The `bucket` node from the response, or None when the response has
        no `model` key or the model has no `bucket` key.

    Side effects:
        On success, stores `project_name` as the `project` setting and, if
        the model includes an `entity`, stores its name as the `entity`
        setting.
    """
    # Pulling wandbConfig.start_time is required so that we can determine if a run has actually started
    status_query = gql(
        """
        query RunResumeStatus($project: String, $entity: String, $name: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                entity {
                    id
                    name
                }
                bucket(name: $name, missingOk: true) {
                    id
                    name
                    summaryMetrics
                    displayName
                    logLineCount
                    historyLineCount
                    eventsLineCount
                    historyTail
                    eventsTail
                    config
                    tags
                    wandbConfig(keys: ["t"])
                }
            }
        }
        """
    )
    variables = {
        "entity": entity,
        "project": project_name,
        "name": name,
    }
    response = self.gql(status_query, variable_values=variables)

    if "model" not in response:
        return None
    project = response["model"]
    # A null/empty model is treated the same as one without a bucket.
    if "bucket" not in (project or {}):
        return None

    self.set_setting("project", project_name)
    if "entity" in project:
        self.set_setting("entity", project["entity"]["name"])

    bucket: dict[str, Any] = project["bucket"]
    return bucket
@normalize_exceptions
def check_stop_requested(
    self, project_name: str, entity_name: str, run_id: str
) -> bool:
    """Return the `stopped` flag of a run.

    Args:
        project_name (str): Project containing the run.
        entity_name (str): Entity owning the project.
        run_id (str): Name of the run to check.

    Returns:
        The run's `stopped` value, or False when the project or run is
        missing/empty in the response.
    """
    stop_query = gql(
        """
        query RunStoppedStatus($projectName: String, $entityName: String, $runId: String!) {
            project(name:$projectName, entityName:$entityName) {
                run(name:$runId) {
                    stopped
                }
            }
        }
        """
    )
    variables = {
        "projectName": project_name,
        "entityName": entity_name,
        "runId": run_id,
    }
    response = self.gql(stop_query, variable_values=variables)

    project_node = response.get("project")
    run_node = project_node.get("run") if project_node else None
    if not run_node:
        return False
    stopped: bool = run_node["stopped"]
    return stopped
def format_project(self, project: str) -> str:
    """Normalize a project name.

    The name is lowercased, every run of non-word characters is replaced
    by a single `-`, and leading/trailing `-` and `_` are stripped.

    Args:
        project (str): The raw project name.

    Returns:
        The normalized name.
    """
    lowered = project.lower()
    dashed = re.sub(r"\W+", "-", lowered)
    return dashed.strip("-_")
@normalize_exceptions
def upsert_project(
    self,
    project: str,
    id: str | None = None,
    description: str | None = None,
    entity: str | None = None,
) -> dict[str, Any]:
    """Create or update a project.

    Args:
        project (str): The project name; it is passed through
            `format_project` before being sent.
        id (str, optional): Sent as the mutation's `id` input.
        description (str, optional): A description of this project.
        entity (str, optional): The entity to scope this project to. Falls
            back to the `entity` setting.

    Returns:
        The `model` node of the `upsertModel` payload (`name`,
        `description`).
    """
    upsert_mutation = gql(
        """
        mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String) {
            upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                model {
                    name
                    description
                }
            }
        }
        """
    )
    variables = {
        "name": self.format_project(project),
        "entity": entity or self.settings("entity"),
        "description": description,
        "id": id,
    }
    response = self.gql(upsert_mutation, variable_values=variables)
    model_node: dict[str, Any] = response["upsertModel"]["model"]
    return model_node
@normalize_exceptions
def entity_is_team(self, entity: str) -> bool:
    """Return the `isTeam` flag of an entity.

    Args:
        entity (str): Name of the entity to look up.

    Returns:
        The entity's `isTeam` value.

    Raises:
        Exception: If the response contains no entity.
    """
    team_query = gql(
        """
        query EntityIsTeam($entity: String!) {
            entity(name: $entity) {
                id
                isTeam
            }
        }
        """
    )
    response = self.gql(team_query, {"entity": entity})
    entity_node = response.get("entity")
    if entity_node is None:
        raise Exception(
            f"Error fetching entity {entity} "
            "check that you have access to this entity"
        )
    is_team: bool = entity_node["isTeam"]
    return is_team
@normalize_exceptions
def get_project_run_queues(self, entity: str, project: str) -> list[dict[str, str]]:
    """Fetch the run queues of a project.

    Args:
        entity (str): Entity owning the project.
        project (str): Project whose run queues are listed.

    Returns:
        The project's `runQueues` list (`id`, `name`, `createdBy`,
        `access`).

    Raises:
        Exception: If the response contains no project.
    """
    queues_query = gql(
        """
        query ProjectRunQueues($entity: String!, $projectName: String!){
            project(entityName: $entity, name: $projectName) {
                runQueues {
                    id
                    name
                    createdBy
                    access
                }
            }
        }
        """
    )
    response = self.gql(queues_query, {"projectName": project, "entity": entity})
    project_node = response.get("project")
    if project_node is None:
        # circular dependency: (LAUNCH_DEFAULT_PROJECT = model-registry)
        # The project name is left out of the message for that project.
        target = entity if project == "model-registry" else f"{entity}/{project}"
        raise Exception(
            f"Error fetching run queues for {target} "
            "check that you have access to this entity and project"
        )
    run_queues: list[dict[str, str]] = project_node["runQueues"]
    return run_queues
@normalize_exceptions
def create_default_resource_config(
    self,
    entity: str,
    resource: str,
    config: str,
    template_variables: dict[str, float | int | str] | None,
) -> dict[str, Any] | None:
    """Send the `createDefaultResourceConfig` mutation.

    Args:
        entity (str): Sent as `entityName`.
        resource (str): Sent as `resource`.
        config (str): Sent as the `config` JSON string.
        template_variables (dict, optional): JSON-encoded and sent as
            `templateVariables`; None is sent as the string `"{}"`.

    Returns:
        The `createDefaultResourceConfig` payload
        (`defaultResourceConfigID`, `success`).
    """
    if template_variables is None:
        encoded_template_variables = "{}"
    else:
        encoded_template_variables = json.dumps(template_variables)
    variable_values = {
        "entityName": entity,
        "resource": resource,
        "config": config,
        "templateVariables": encoded_template_variables,
    }

    mutation_params = """
        $entityName: String!,
        $resource: String!,
        $config: JSONString!,
        $templateVariables: JSONString
    """
    mutation_inputs = """
        entityName: $entityName,
        resource: $resource,
        config: $config,
        templateVariables: $templateVariables
    """
    mutation = gql(
        f"""
        mutation createDefaultResourceConfig(
            {mutation_params}
        ) {{
            createDefaultResourceConfig(
                input: {{
                    {mutation_inputs}
                }}
            ) {{
                defaultResourceConfigID
                success
            }}
        }}
        """
    )
    response = self.gql(mutation, variable_values)
    payload: dict[str, Any] | None = response["createDefaultResourceConfig"]
    return payload
@normalize_exceptions
def create_run_queue(
    self,
    entity: str,
    project: str,
    queue_name: str,
    access: str,
    prioritization_mode: str | None = None,
    config_id: str | None = None,
) -> dict[str, Any] | None:
    """Send the `createRunQueue` mutation.

    Args:
        entity (str): Sent as `entityName`.
        project (str): Sent as `projectName`.
        queue_name (str): Sent as `queueName`.
        access (str): Sent as the `RunQueueAccessType` value.
        prioritization_mode (str, optional): Sent as `prioritizationMode`.
        config_id (str, optional): Sent as `defaultResourceConfigID`.

    Returns:
        The `createRunQueue` payload (`success`, `queueID`).
    """
    variables = {
        "entity": entity,
        "project": project,
        "queueName": queue_name,
        "access": access,
        "prioritizationMode": prioritization_mode,
        "defaultResourceConfigID": config_id,
    }
    mutation = gql(
        """
        mutation createRunQueue(
            $entity: String!,
            $project: String!,
            $queueName: String!,
            $access: RunQueueAccessType!,
            $prioritizationMode: RunQueuePrioritizationMode,
            $defaultResourceConfigID: ID,
        ) {
            createRunQueue(
                input: {
                    entityName: $entity,
                    projectName: $project,
                    queueName: $queueName,
                    access: $access,
                    prioritizationMode: $prioritizationMode
                    defaultResourceConfigID: $defaultResourceConfigID
                }
            ) {
                success
                queueID
            }
        }
        """
    )
    response = self.gql(mutation, variables)
    payload: dict[str, Any] | None = response["createRunQueue"]
    return payload
@normalize_exceptions
def upsert_run_queue(
    self,
    queue_name: str,
    entity: str,
    resource_type: str,
    resource_config: dict,
    project: str = LAUNCH_DEFAULT_PROJECT,
    prioritization_mode: str | None = None,
    template_variables: dict | None = None,
    external_links: dict | None = None,
) -> dict[str, Any] | None:
    """Send the `upsertRunQueue` mutation.

    Args:
        queue_name (str): Sent as `queueName`.
        entity (str): Sent as `entityName`.
        resource_type (str): Sent as `resourceType`.
        resource_config (dict): JSON-encoded and sent as `resourceConfig`.
        project (str): Sent as `projectName`.
        prioritization_mode (str, optional): Sent as `prioritizationMode`.
        template_variables (dict, optional): JSON-encoded when truthy,
            otherwise sent as None.
        external_links (dict, optional): JSON-encoded when truthy,
            otherwise sent as None.

    Returns:
        The `upsertRunQueue` payload (`success`,
        `configSchemaValidationErrors`).
    """
    encoded_template_variables = (
        json.dumps(template_variables) if template_variables else None
    )
    encoded_external_links = json.dumps(external_links) if external_links else None

    mutation = gql(
        """
        mutation upsertRunQueue(
            $entityName: String!
            $projectName: String!
            $queueName: String!
            $resourceType: String!
            $resourceConfig: JSONString!
            $templateVariables: JSONString
            $prioritizationMode: RunQueuePrioritizationMode
            $externalLinks: JSONString
            $clientMutationId: String
        ) {
            upsertRunQueue(
                input: {
                    entityName: $entityName
                    projectName: $projectName
                    queueName: $queueName
                    resourceType: $resourceType
                    resourceConfig: $resourceConfig
                    templateVariables: $templateVariables
                    prioritizationMode: $prioritizationMode
                    externalLinks: $externalLinks
                    clientMutationId: $clientMutationId
                }
            ) {
                success
                configSchemaValidationErrors
            }
        }
        """
    )
    variables = {
        "entityName": entity,
        "projectName": project,
        "queueName": queue_name,
        "resourceType": resource_type,
        "resourceConfig": json.dumps(resource_config),
        "templateVariables": encoded_template_variables,
        "prioritizationMode": prioritization_mode,
        "externalLinks": encoded_external_links,
    }
    response: dict[str, Any] = self.gql(mutation, variables)
    return response["upsertRunQueue"]
@normalize_exceptions
def push_to_run_queue_by_name(
    self,
    entity: str,
    project: str,
    queue_name: str,
    run_spec: str,
    template_variables: dict[str, int | float | str] | None,
    priority: int | None = None,
) -> dict[str, Any] | None:
    """Push a run spec onto a queue identified by name.

    Args:
        entity (str): Sent as `entityName`.
        project (str): Sent as `projectName`.
        queue_name (str): Sent as `queueName`.
        run_spec (str): JSON string sent as `runSpec`.
        template_variables (dict, optional): When not None, JSON-encoded
            and sent as `templateVariableValues`.
        priority (int, optional): When not None, sent as `priority`.

    Returns:
        The `pushToRunQueueByName` payload, with `runSpec` decoded from
        JSON when present. None is returned when the payload is empty or
        when a request fails; the only failure that is retried is the
        server not knowing the `runSpec` payload field, in which case the
        mutation is re-sent without selecting it.
    """
    variables: dict[str, Any] = {
        "entityName": entity,
        "projectName": project,
        "queueName": queue_name,
        "runSpec": run_spec,
    }
    param_decls = """
        $entityName: String!,
        $projectName: String!,
        $queueName: String!,
        $runSpec: JSONString!
    """
    input_fields = """
        entityName: $entityName,
        projectName: $projectName,
        queueName: $queueName,
        runSpec: $runSpec
    """
    # Optional arguments are only declared in the mutation when provided.
    if priority is not None:
        variables["priority"] = priority
        param_decls += ", $priority: Int"
        input_fields += ", priority: $priority"
    if template_variables is not None:
        variables.update({"templateVariableValues": json.dumps(template_variables)})
        param_decls += ", $templateVariableValues: JSONString"
        input_fields += ", templateVariableValues: $templateVariableValues"

    mutation = gql(
        f"""
        mutation pushToRunQueueByName(
            {param_decls}
        ) {{
            pushToRunQueueByName(
                input: {{
                    {input_fields}
                }}
            ) {{
                runQueueItemId
                runSpec
            }}
        }}
        """
    )

    try:
        pushed: dict[str, Any] | None = self.gql(
            mutation, variables, check_retry_fn=util.no_retry_4xx
        ).get("pushToRunQueueByName")
        if not pushed:
            return None
        raw_spec = pushed.get("runSpec")
        if raw_spec:
            pushed["runSpec"] = json.loads(str(raw_spec))
        return pushed
    except Exception as e:
        # Any error other than the missing `runSpec` field ends the attempt.
        if (
            'Cannot query field "runSpec" on type "PushToRunQueueByNamePayload"'
            not in str(e)
        ):
            return None

    # The server does not know the `runSpec` payload field: retry without it.
    mutation_no_runspec = gql(
        """
        mutation pushToRunQueueByName(
            $entityName: String!,
            $projectName: String!,
            $queueName: String!,
            $runSpec: JSONString!,
        ) {
            pushToRunQueueByName(
                input: {
                    entityName: $entityName,
                    projectName: $projectName,
                    queueName: $queueName,
                    runSpec: $runSpec
                }
            ) {
                runQueueItemId
            }
        }
        """
    )
    try:
        return self.gql(
            mutation_no_runspec, variables, check_retry_fn=util.no_retry_4xx
        ).get("pushToRunQueueByName")
    except Exception:
        return None
@normalize_exceptions
def push_to_run_queue(
    self,
    queue_name: str,
    launch_spec: dict[str, str],
    template_variables: dict | None,
    project_queue: str,
    priority: int | None = None,
) -> dict[str, Any] | None:
    """Push a launch spec onto a run queue.

    `push_to_run_queue_by_name` is tried first. If it yields nothing and no
    priority was requested, the legacy path looks the queue up by name
    among the project's queues (creating the `default` queue if it is
    missing) and pushes by queue ID.

    Args:
        queue_name (str): Name of the target queue.
        launch_spec (dict): The launch spec; `queue_entity` (or `entity`)
            selects the entity whose queues are used.
        template_variables (dict, optional): When not None, JSON-encoded
            and sent as `templateVariableValues`.
        project_queue (str): Project that holds the queue.
        priority (int, optional): Priority for the by-name push.

    Returns:
        The push payload, or None when the run could not be queued.

    Raises:
        CommError: If the legacy push returns an empty payload.
    """
    entity = launch_spec.get("queue_entity") or launch_spec["entity"]
    run_spec = json.dumps(launch_spec)

    by_name_result = self.push_to_run_queue_by_name(
        entity, project_queue, queue_name, run_spec, template_variables, priority
    )
    if by_name_result:
        return by_name_result
    if priority is not None:
        # Cannot proceed with legacy method if priority is set
        return None

    # Legacy Method
    candidates = [
        queue
        for queue in self.get_project_run_queues(entity, project_queue)
        if queue["name"] == queue_name
        # ensure user has access to queue
        and (
            # TODO: User created queues in the UI have USER access
            queue["access"] in ["PROJECT", "USER"]
            or queue["createdBy"] == self.default_entity
        )
    ]

    if len(candidates) > 1:
        wandb.termerror(
            f"Unable to push to run queue {queue_name}. More than one queue found with this name."
        )
        return None

    if candidates:
        queue_id = candidates[0]["id"]
    elif queue_name != "default":
        if project_queue == "model-registry":
            queue_label = queue_name
        else:
            queue_label = f"{project_queue}/{queue_name}"
        wandb.termwarn(f"Unable to push to run queue {queue_label}. Queue not found.")
        return None
    else:
        # in the case of a missing default queue. create it
        wandb.termlog(
            f"No default queue existing for entity: {entity} in project: {project_queue}, creating one."
        )
        created = self.create_run_queue(
            launch_spec["entity"],
            project_queue,
            queue_name,
            access="PROJECT",
        )
        if created is None or created.get("queueID") is None:
            wandb.termerror(
                f"Unable to create default queue for entity: {entity} on project: {project_queue}. Run could not be added to a queue"
            )
            return None
        queue_id = created["queueID"]

    variables = {"queueID": queue_id, "runSpec": run_spec}
    mutation_params = """
        $queueID: ID!,
        $runSpec: JSONString!
    """
    mutation_input = """
        queueID: $queueID,
        runSpec: $runSpec
    """
    if template_variables is not None:
        mutation_params += ", $templateVariableValues: JSONString"
        mutation_input += ", templateVariableValues: $templateVariableValues"
        variables["templateVariableValues"] = json.dumps(template_variables)

    mutation = gql(
        f"""
        mutation pushToRunQueue(
            {mutation_params}
        ) {{
            pushToRunQueue(
                input: {{{mutation_input}}}
            ) {{
                runQueueItemId
            }}
        }}
        """
    )
    response = self.gql(mutation, variable_values=variables)
    if not response.get("pushToRunQueue"):
        raise CommError(f"Error pushing run queue item to queue {queue_name}.")

    pushed: dict[str, Any] | None = response["pushToRunQueue"]
    return pushed
@normalize_exceptions
def pop_from_run_queue(
    self,
    queue_name: str,
    entity: str | None = None,
    project: str | None = None,
    agent_id: str | None = None,
) -> dict[str, Any] | None:
    """Send the `popFromRunQueue` mutation.

    Args:
        queue_name (str): Sent as `queueName`.
        entity (str, optional): Sent as `entityName`.
        project (str, optional): Sent as `projectName`.
        agent_id (str, optional): Sent as `launchAgentId`.

    Returns:
        The `popFromRunQueue` payload (`runQueueItemId`, `runSpec`), which
        may be None.
    """
    variables = {
        "entity": entity,
        "project": project,
        "queueName": queue_name,
        "launchAgentId": agent_id,
    }
    mutation = gql(
        """
        mutation popFromRunQueue($entity: String!, $project: String!, $queueName: String!, $launchAgentId: ID) {
            popFromRunQueue(input: {
                entityName: $entity,
                projectName: $project,
                queueName: $queueName,
                launchAgentId: $launchAgentId
            }) {
                runQueueItemId
                runSpec
            }
        }
        """
    )
    response = self.gql(mutation, variable_values=variables)
    popped: dict[str, Any] | None = response["popFromRunQueue"]
    return popped
@normalize_exceptions
def ack_run_queue_item(self, item_id: str, run_id: str | None = None) -> bool:
    """Acknowledge a run queue item.

    Args:
        item_id (str): ID of the run queue item.
        run_id (str, optional): Passed through `str()` and sent as
            `runName`.

    Returns:
        The `success` flag from the mutation payload.

    Raises:
        CommError: If the server reports the ack as unsuccessful.
    """
    ack_mutation = gql(
        """
        mutation ackRunQueueItem($itemId: ID!, $runId: String!) {
            ackRunQueueItem(input: { runQueueItemId: $itemId, runName: $runId }) {
                success
            }
        }
        """
    )
    variables = {"itemId": item_id, "runId": str(run_id)}
    response = self.gql(ack_mutation, variable_values=variables)

    acked: bool = response["ackRunQueueItem"]["success"]
    if not acked:
        raise CommError(
            "Error acking run queue item. Item may have already been acknowledged by another process"
        )
    return acked
@normalize_exceptions
def create_launch_agent(
    self,
    entity: str,
    project: str,
    queues: list[str],
    agent_config: dict[str, Any],
    version: str,
) -> dict:
    """Register a launch agent that polls the given queues.

    Args:
        entity (str): Sent as `entityName`.
        project (str): Sent as `projectName`.
        queues (list[str]): Names of the queues the agent should poll.
        agent_config (dict): JSON-encoded and sent as `agentConfig`.
        version (str): Sent as `version`.

    Returns:
        The `createLaunchAgent` payload (`launchAgentId`).

    Raises:
        CommError: If the project has no queues and the default queue
            cannot be created, or if the number of matched queues differs
            from the number requested.
    """
    available_queues = self.get_project_run_queues(entity, project)
    if not available_queues:
        # create default queue if it doesn't already exist
        created = self.create_run_queue(
            entity, project, "default", access="PROJECT"
        )
        if created is None or created.get("queueID") is None:
            raise CommError(
                f"Unable to create default queue for {entity}/{project}. No queues for agent to poll"
            )
        available_queues = [{"id": created["queueID"], "name": "default"}]

    # filter to poll specified queues
    polling_queue_ids = [
        queue["id"] for queue in available_queues if queue["name"] in queues
    ]
    if len(polling_queue_ids) != len(queues):
        raise CommError(
            f"Could not start launch agent: Not all of requested queues ({', '.join(queues)}) found. "
            f"Available queues for this project: {','.join([q['name'] for q in available_queues])}"
        )

    variables = {
        "entity": entity,
        "project": project,
        "queues": polling_queue_ids,
        "hostname": socket.gethostname(),
        "agentConfig": json.dumps(agent_config),
        "version": version,
    }

    mutation_params = """
        $entity: String!,
        $project: String!,
        $queues: [ID!]!,
        $hostname: String!,
        $agentConfig: JSONString,
        $version: String
    """
    mutation_input = """
        entityName: $entity,
        projectName: $project,
        runQueues: $queues,
        hostname: $hostname,
        agentConfig: $agentConfig,
        version: $version
    """
    mutation = gql(
        f"""
        mutation createLaunchAgent(
            {mutation_params}
        ) {{
            createLaunchAgent(
                input: {{
                    {mutation_input}
                }}
            ) {{
                launchAgentId
            }}
        }}
        """
    )
    response = self.gql(mutation, variables)
    agent: dict = response["createLaunchAgent"]
    return agent
@normalize_exceptions
def update_launch_agent_status(
    self,
    agent_id: str,
    status: str,
) -> dict:
    """Send the `updateLaunchAgent` mutation with a new agent status.

    Args:
        agent_id (str): Sent as `launchAgentId`.
        status (str): Sent as `agentStatus`.

    Returns:
        The `updateLaunchAgent` payload (`success`).
    """
    variables = {
        "agentId": agent_id,
        "agentStatus": status,
    }
    mutation = gql(
        """
        mutation updateLaunchAgent($agentId: ID!, $agentStatus: String){
            updateLaunchAgent(
                input: {
                    launchAgentId: $agentId
                    agentStatus: $agentStatus
                }
            ) {
                success
            }
        }
        """
    )
    response = self.gql(mutation, variables)
    payload: dict = response["updateLaunchAgent"]
    return payload
@normalize_exceptions
def get_launch_agent(self, agent_id: str) -> dict:
    """Fetch a launch agent by ID.

    Args:
        agent_id (str): ID of the launch agent.

    Returns:
        The `launchAgent` node (`id`, `name`, `runQueues`, `hostname`,
        `agentStatus`, `stopPolling`, `heartbeatAt`).
    """
    agent_query = gql(
        """
        query LaunchAgent($agentId: ID!) {
            launchAgent(id: $agentId) {
                id
                name
                runQueues
                hostname
                agentStatus
                stopPolling
                heartbeatAt
            }
        }
        """
    )
    response = self.gql(agent_query, {"agentId": agent_id})
    agent: dict = response["launchAgent"]
    return agent
@normalize_exceptions
def upsert_run(
    self,
    id: str | None = None,
    name: str | None = None,
    project: str | None = None,
    host: str | None = None,
    group: str | None = None,
    tags: list[str] | None = None,
    config: dict | None = None,
    description: str | None = None,
    entity: str | None = None,
    state: str | None = None,
    display_name: str | None = None,
    notes: str | None = None,
    repo: str | None = None,
    job_type: str | None = None,
    program_path: str | None = None,
    commit: str | None = None,
    sweep_name: str | None = None,
    summary_metrics: str | None = None,
    num_retries: int | None = None,
) -> tuple[dict, bool]:
    """Create or update a run.

    Args:
        id (str, optional): The existing run to update
        name (str, optional): The name of the run to create
        project (str, optional): The name of the project. Falls back to
            `util.auto_project_name(program_path)`.
        host (str, optional): The name of the host. Not sent when the
            `anonymous` setting is `allow` or `must`.
        group (str, optional): Name of the group this run is a part of
        tags (list, optional): A list of tags to apply to the run
        config (dict, optional): The latest config params; JSON-encoded
            when truthy.
        description (str, optional): A description of this run; empty or
            whitespace-only values are sent as None.
        entity (str, optional): The entity to scope this project to. Falls
            back to the `entity` setting.
        state (str, optional): State of the program.
        display_name (str, optional): The display name of this run
        notes (str, optional): Notes about this run
        repo (str, optional): Url of the program's repository.
        job_type (str, optional): Type of job, e.g 'train'.
        program_path (str, optional): Path to the program.
        commit (str, optional): The Git SHA to associate the run with
        sweep_name (str, optional): The name of the sweep this run is a part of
        summary_metrics (str, optional): The JSON summary metrics
        num_retries (int, optional): Number of retries; forwarded to
            `self.gql` only when not None.

    Returns:
        A `(bucket, inserted)` tuple taken from the `upsertBucket` payload.

    Side effects:
        When the returned run includes a project, its name is stored as the
        `project` setting, and the project's entity name (if present) as
        the `entity` setting.
    """
    mutation = gql(
        """
        mutation UpsertBucket(
            $id: String,
            $name: String,
            $project: String,
            $entity: String,
            $groupName: String,
            $description: String,
            $displayName: String,
            $notes: String,
            $commit: String,
            $config: JSONString,
            $host: String,
            $debug: Boolean,
            $program: String,
            $repo: String,
            $jobType: String,
            $state: String,
            $sweep: String,
            $tags: [String!],
            $summaryMetrics: JSONString,
        ) {
            upsertBucket(input: {
                id: $id,
                name: $name,
                groupName: $groupName,
                modelName: $project,
                entityName: $entity,
                description: $description,
                displayName: $displayName,
                notes: $notes,
                config: $config,
                commit: $commit,
                host: $host,
                debug: $debug,
                jobProgram: $program,
                jobRepo: $repo,
                jobType: $jobType,
                state: $state,
                sweep: $sweep,
                tags: $tags,
                summaryMetrics: $summaryMetrics,
            }) {
                bucket {
                    id
                    name
                    displayName
                    description
                    config
                    sweepName
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
                    historyLineCount
                }
                inserted
            }
        }
        """
    )

    config_str = json.dumps(config) if config else None
    if not description or description.isspace():
        description = None

    # Only override the retry count when the caller asked for one.
    gql_kwargs = {} if num_retries is None else {"num_retries": num_retries}

    resolved_entity = entity or self.settings("entity")
    resolved_project = project or util.auto_project_name(program_path)
    if self.settings().get("anonymous") in ["allow", "must"]:
        reported_host = None
    else:
        reported_host = host

    variable_values = {
        "id": id,
        "entity": resolved_entity,
        "name": name,
        "project": resolved_project,
        "groupName": group,
        "tags": tags,
        "description": description,
        "config": config_str,
        "commit": commit,
        "displayName": display_name,
        "notes": notes,
        "host": reported_host,
        "debug": env.is_debug(env=self._environ),
        "repo": repo,
        "program": program_path,
        "jobType": job_type,
        "state": state,
        "sweep": sweep_name,
        "summaryMetrics": summary_metrics,
    }

    # retry conflict errors for 2 minutes, default to no_auth_retry
    check_retry_fn = util.make_check_retry_fn(
        check_fn=util.check_retry_conflict_or_gone,
        check_timedelta=datetime.timedelta(minutes=2),
        fallback_retry_fn=util.no_retry_auth,
    )

    response = self.gql(
        mutation,
        variable_values=variable_values,
        check_retry_fn=check_retry_fn,
        **gql_kwargs,
    )

    upserted = response["upsertBucket"]
    run_obj: dict[str, dict[str, dict[str, str]]] = upserted["bucket"]
    project_obj: dict[str, dict[str, str]] = run_obj.get("project", {})
    if project_obj:
        self.set_setting("project", project_obj["name"])
        entity_obj = project_obj.get("entity", {})
        if entity_obj:
            self.set_setting("entity", entity_obj["name"])

    return run_obj, upserted["inserted"]
@normalize_exceptions
def rewind_run(
    self,
    run_name: str,
    metric_name: str,
    metric_value: float,
    program_path: str | None = None,
    entity: str | None = None,
    project: str | None = None,
    num_retries: int | None = None,
) -> dict:
    """Rewind a run to the point identified by a metric value.

    Issues the ``rewindRun`` mutation. When the response carries the run's
    project (and that project's entity), their names are stored through
    ``self.set_setting`` before returning.

    Args:
        run_name: The name of the run to rewind.
        metric_name: The name of the metric to rewind to.
        metric_value: The value of the metric to rewind to.
        program_path: Path to the program; handed to
            ``util.auto_project_name`` when ``project`` is not given.
        entity: The entity to scope this project to. Falls back to the
            ``entity`` setting.
        project: The name of the project.
        num_retries: Forwarded to ``self.gql`` as ``num_retries`` when it
            is not None.

    Returns:
        The ``rewoundRun`` object of the response, shaped like::

            {
                "id": "run_id",
                "name": "run_name",
                "displayName": "run_display_name",
                "description": "run_description",
                "config": "stringified_run_config_json",
                "sweepName": "run_sweep_name",
                "project": {
                    "id": "project_id",
                    "name": "project_name",
                    "entity": {"id": "entity_id", "name": "entity_name"},
                },
                "historyLineCount": 100,
            }

        An empty dict is returned when the response lacks those keys.
    """
    mutation = gql(
        """
    mutation RewindRun($runName: String!, $entity: String, $project: String, $metricName: String!, $metricValue: Float!) {
        rewindRun(input: {runName: $runName, entityName: $entity, projectName: $project, metricName: $metricName, metricValue: $metricValue}) {
            rewoundRun {
                id
                name
                displayName
                description
                config
                sweepName
                project {
                    id
                    name
                    entity {
                        id
                        name
                    }
                }
                historyLineCount
            }
        }
    }
    """
    )

    # Only override the retry count when the caller asked for one.
    gql_kwargs = {} if num_retries is None else {"num_retries": num_retries}

    request_variables = {
        "runName": run_name,
        "entity": entity or self.settings("entity"),
        "project": project or util.auto_project_name(program_path),
        "metricName": metric_name,
        "metricValue": metric_value,
    }

    # Conflict errors are retried for two minutes; everything else is
    # delegated to util.no_retry_auth.
    retry_policy = util.make_check_retry_fn(
        check_fn=util.check_retry_conflict_or_gone,
        check_timedelta=datetime.timedelta(minutes=2),
        fallback_retry_fn=util.no_retry_auth,
    )

    response = self.gql(
        mutation,
        variable_values=request_variables,
        check_retry_fn=retry_policy,
        **gql_kwargs,
    )

    rewound: dict[str, dict[str, dict[str, str]]] = response.get(
        "rewindRun", {}
    ).get("rewoundRun", {})

    project_info: dict[str, dict[str, str]] = rewound.get("project", {})
    if not project_info:
        return rewound

    # Remember where the run lives so later calls default to it.
    self.set_setting("project", project_info["name"])
    entity_info = project_info.get("entity", {})
    if entity_info:
        self.set_setting("entity", entity_info["name"])
    return rewound
@normalize_exceptions
def get_run_info(
    self,
    entity: str,
    project: str,
    name: str,
) -> dict:
    """Fetch the ``runInfo`` record of a run.

    Args:
        entity: Entity that owns the project.
        project: Project that contains the run.
        name: Name of the run.

    Returns:
        The ``runInfo`` object from the response (program, args, os,
        python, colab, executable, codeSaved, cpuCount, gpuCount, gpu and
        git remote/commit, as requested by the query).

    Raises:
        CommError: If the response has no project, or the project has no
            such run.
    """
    query = gql(
        """
    query RunInfo($project: String!, $entity: String!, $name: String!) {
        project(name: $project, entityName: $entity) {
            run(name: $name) {
                runInfo {
                    program
                    args
                    os
                    python
                    colab
                    executable
                    codeSaved
                    cpuCount
                    gpuCount
                    gpu
                    git {
                        remote
                        commit
                    }
                }
            }
        }
    }
    """
    )
    res = self.gql(query, {"project": project, "entity": entity, "name": name})

    project_res = res.get("project")
    if project_res is None:
        raise CommError(
            f"Error fetching run info for {entity}/{project}/{name}. Check that this project exists and you have access to this entity and project"
        )

    run_res = project_res.get("run")
    if run_res is None:
        raise CommError(
            f"Error fetching run info for {entity}/{project}/{name}. Check that this run id exists"
        )

    run_info: dict = run_res["runInfo"]
    return run_info
@normalize_exceptions
def get_run_state(self, entity: str, project: str, name: str) -> str:
    """Return the ``state`` field of a run.

    Args:
        entity: Entity that owns the project.
        project: Project that contains the run.
        name: Name of the run.

    Returns:
        The run's ``state`` value as reported by the server.

    Raises:
        CommError: If the response has no project or no such run.
    """
    query = gql(
        """
    query RunState(
        $project: String!,
        $entity: String!,
        $name: String!) {
        project(name: $project, entityName: $entity) {
            run(name: $name) {
                state
            }
        }
    }
    """
    )
    res = self.gql(query, {"project": project, "entity": entity, "name": name})

    project_res = res.get("project")
    run_res = None if project_res is None else project_res.get("run")
    if run_res is None:
        raise CommError(f"Error fetching run state for {entity}/{project}/{name}.")

    run_state: str = run_res["state"]
    return run_state
@normalize_exceptions
def upload_urls(
    self,
    project: str,
    files: list[str] | dict[str, IO],
    run: str | None = None,
    entity: str | None = None,
    description: str | None = None,
) -> tuple[str, list[str], dict[str, dict[str, Any]]]:
    """Request upload URLs for a set of run files.

    Uses the ``createRunFiles`` mutation.

    Args:
        project: The project that contains the run.
        files: The filenames to upload. When a dict is given, its keys are
            used as the filenames.
        run: The run to upload to. Defaults to ``self.current_run_id``.
        entity: The entity to scope this project to. Defaults to the
            ``entity`` setting.
        description: Not read by this method.

    Returns:
        A ``(run_id, upload_headers, file_info)`` tuple:

        run_id: id of the run the files belong to.
        upload_headers: A list of headers to use when uploading files.
        file_info: A dict mapping each filename to its entry from the
            response, e.g.
            ``{"weights.h5": {"name": "weights.h5", "uploadUrl": "https://weights.url"}}``.

    Raises:
        CommError: If the response carries no run id.
    """
    run_name = run or self.current_run_id
    assert run_name, "run must be specified"
    entity = entity or self.settings("entity")
    assert entity, "entity must be specified"

    mutation = gql(
        """
    mutation CreateRunFiles($entity: String!, $project: String!, $run: String!, $files: [String!]!) {
        createRunFiles(input: {entityName: $entity, projectName: $project, runName: $run, files: $files}) {
            runID
            uploadHeaders
            files {
                name
                uploadUrl
            }
        }
    }
    """
    )
    request_variables = {
        "project": project,
        "run": run_name,
        "entity": entity,
        # Iterating a dict yields its keys, i.e. the filenames.
        "files": list(files),
    }
    created = self.gql(mutation, variable_values=request_variables)["createRunFiles"]

    run_id = created["runID"]
    if not run_id:
        raise CommError(
            f"Error uploading files to {entity}/{project}/{run_name}. Check that this project exists and you have access to this entity and project"
        )

    info_by_name = {entry["name"]: entry for entry in created["files"]}
    return run_id, created["uploadHeaders"], info_by_name
def legacy_upload_urls(
    self,
    project: str,
    files: list[str] | dict[str, IO],
    run: str | None = None,
    entity: str | None = None,
    description: str | None = None,
) -> tuple[str, list[str], dict[str, dict[str, Any]]]:
    """Generate temporary resumable upload urls (legacy query).

    A new mutation createRunFiles was introduced after 0.15.4.
    This function is used to support older versions.

    Args:
        project: The project that contains the run.
        files: The filenames to upload. When a dict is given, its keys are
            used as the filenames.
        run: The run to upload to. Defaults to ``self.current_run_id``.
        entity: The entity to scope this project to. Defaults to the
            ``entity`` setting.
        description: Sent to the server as the ``desc`` argument of the
            run lookup.

    Returns:
        A ``(run_id, upload_headers, file_info)`` tuple where ``file_info``
        maps each filename to its file node. Each node exposes its upload
        URL under ``"uploadUrl"``.

    Raises:
        CommError: If the response contains no project or no run.
    """
    query = gql(
        """
    query RunUploadUrls($name: String!, $files: [String]!, $entity: String, $run: String!, $description: String) {
        model(name: $name, entityName: $entity) {
            bucket(name: $run, desc: $description) {
                id
                files(names: $files) {
                    uploadHeaders
                    edges {
                        node {
                            name
                            url(upload: true)
                            updatedAt
                        }
                    }
                }
            }
        }
    }
    """
    )
    run_id = run or self.current_run_id
    assert run_id, "run must be specified"
    entity = entity or self.settings("entity")
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run_id,
            "entity": entity,
            "files": list(files),
            "description": description,
        },
    )

    # Either level of the response may come back as null. Treat both the
    # same way instead of failing on a None subscript for a null `model`.
    model_obj = query_result["model"]
    run_obj = model_obj["bucket"] if model_obj else None
    if not run_obj:
        raise CommError(f"Run does not exist {entity}/{project}/{run_id}.")

    files_obj = run_obj["files"]
    for edge in files_obj["edges"]:
        node = edge["node"]
        # we previously used "url" field but now use "uploadUrl"
        # replace the "url" field with "uploadUrl" for downstream compatibility
        if "url" in node and "uploadUrl" not in node:
            node["uploadUrl"] = node.pop("url")

    result = {file["name"]: file for file in self._flatten_edges(files_obj)}
    return run_obj["id"], files_obj["uploadHeaders"], result
@normalize_exceptions
def download_urls(
    self,
    project: str,
    run: str | None = None,
    entity: str | None = None,
) -> dict[str, dict[str, str]]:
    """Generate download urls for every file of a run.

    Args:
        project: The project that contains the run.
        run: The run to download from. Defaults to ``self.current_run_id``.
        entity: The entity to scope this project to. Defaults to the
            ``entity`` setting.

    Returns:
        A dict of filenames and file records, e.g.::

            {
                'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
            }

    Raises:
        CommError: If the response contains no project or no run.
    """
    query = gql(
        """
    query RunDownloadUrls($name: String!, $entity: String, $run: String!) {
        model(name: $name, entityName: $entity) {
            bucket(name: $run) {
                files {
                    edges {
                        node {
                            name
                            url
                            md5
                            updatedAt
                        }
                    }
                }
            }
        }
    }
    """
    )
    run = run or self.current_run_id
    assert run, "run must be specified"
    entity = entity or self.settings("entity")
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run,
            "entity": entity,
        },
    )

    # A missing run shows up as a null `bucket` under a non-null `model`.
    # Report it with the same error instead of failing on a None subscript.
    model_obj = query_result["model"]
    if model_obj is None or model_obj["bucket"] is None:
        raise CommError(f"Run does not exist {entity}/{project}/{run}.")

    files = self._flatten_edges(model_obj["bucket"]["files"])
    return {file["name"]: file for file in files if file}
@normalize_exceptions
def download_url(
    self,
    project: str,
    file_name: str,
    run: str | None = None,
    entity: str | None = None,
) -> dict[str, str] | None:
    """Generate the download url for a single file of a run.

    Args:
        project: The project that contains the run.
        file_name: The name of the file to download.
        run: The run to download from. Defaults to ``self.current_run_id``.
        entity: The entity to scope this project to. Defaults to the
            ``entity`` setting.

    Returns:
        The file record, e.g.
        ``{ "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }``,
        or None when the project, the run, or the file is not found. A
        record without an ``updatedAt`` value also counts as not found.
    """
    query = gql(
        """
    query RunDownloadUrl($name: String!, $fileName: String!, $entity: String, $run: String!) {
        model(name: $name, entityName: $entity) {
            bucket(name: $run) {
                files(names: [$fileName]) {
                    edges {
                        node {
                            name
                            url
                            md5
                            updatedAt
                        }
                    }
                }
            }
        }
    }
    """
    )
    run = run or self.current_run_id
    assert run, "run must be specified"
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run,
            "fileName": file_name,
            "entity": entity or self.settings("entity"),
        },
    )

    # Both `model` and `bucket` can be null. A missing run is answered
    # with None, just like a missing project, rather than a None subscript.
    model_obj = query_result["model"]
    run_obj = model_obj["bucket"] if model_obj else None
    if not run_obj:
        return None

    files = self._flatten_edges(run_obj["files"])
    if files and files[0].get("updatedAt"):
        return files[0]
    return None
@normalize_exceptions
def download_file(self, url: str) -> tuple[int, requests.Response]:
    """Start a streaming download of ``url``.

    Authenticates with a bearer token when ``self.access_token`` is set,
    otherwise with basic auth using ``self.api_key``.

    Args:
        url: The url to download.

    Returns:
        A tuple of the content length (0 when the response has no
        ``content-length`` header) and the streaming response.
    """
    import requests

    check_httpclient_logger_handler()

    if self.access_token is not None:
        credentials = None
        request_headers = {"Authorization": f"Bearer {self.access_token}"}
    else:
        credentials = ("api", self.api_key or "")
        request_headers = {}

    response = requests.get(
        url,
        auth=credentials,
        headers=request_headers,
        stream=True,
    )
    response.raise_for_status()

    content_length = int(response.headers.get("content-length", 0))
    return content_length, response
@normalize_exceptions
def download_write_file(
    self,
    metadata: dict[str, str],
    out_dir: str | None = None,
) -> tuple[str, requests.Response | None]:
    """Download a file from a run and write it to wandb/.

    Args:
        metadata: The metadata object for the file to download. Comes from
            Api.download_urls(). Its ``name``, ``md5`` and ``url`` entries
            are read.
        out_dir: The directory to write the file to. Defaults to the
            ``wandb_dir`` setting.

    Returns:
        A tuple of the file's local path and the streaming response. The
        streaming response is None if the file already existed and was
        up-to-date.
    """
    filename = metadata["name"]
    path = os.path.join(out_dir or self.settings("wandb_dir"), filename)

    # The freshness check has to look at the destination that is written
    # below. Checking the bare filename would inspect a file relative to
    # the current working directory.
    if self.file_current(path, B64MD5(metadata["md5"])):
        return path, None

    _, response = self.download_file(metadata["url"])
    with util.fsync_open(path, "wb") as file:
        for data in response.iter_content(chunk_size=1024):
            file.write(data)
    return path, response
def upload_file_azure(
    self, url: str, file: Any, extra_headers: dict[str, str]
) -> None:
    """Upload a file to an Azure blob URL.

    Azure SDK errors are re-raised as ``requests`` exceptions:
    ``RequestException`` (carrying status code and headers) when the error
    has a ``response`` attribute, ``ConnectionError`` otherwise.

    Args:
        url: Blob URL to upload to.
        file: The object to upload; it must support ``len()``.
        extra_headers: Request headers. ``Content-MD5`` (base64) and
            ``Content-Type`` are passed on as blob content settings.
    """
    import requests
    from azure.core.exceptions import AzureError  # type: ignore

    blob_module = self._azure_blob_module

    # Configure the client without retries so our existing logic can handle them
    client = blob_module.BlobClient.from_blob_url(
        url, retry_policy=blob_module.LinearRetry(retry_total=0)
    )
    try:
        encoded_md5 = extra_headers.get("Content-MD5")
        md5: bytes | None = (
            None if encoded_md5 is None else base64.b64decode(encoded_md5)
        )
        content_settings = blob_module.ContentSettings(
            content_md5=md5,
            content_type=extra_headers.get("Content-Type"),
        )
        client.upload_blob(
            file,
            max_concurrency=4,
            length=len(file),
            overwrite=True,
            content_settings=content_settings,
        )
    except AzureError as e:
        if not hasattr(e, "response"):
            raise requests.exceptions.ConnectionError(e.message)

        # Mirror the Azure response on a requests.Response so callers can
        # inspect the status code the usual way.
        mirrored = requests.models.Response()
        mirrored.status_code = e.response.status_code
        mirrored.headers = e.response.headers
        raise requests.exceptions.RequestException(e.message, response=mirrored)
def upload_multipart_file_chunk(
    self,
    url: str,
    upload_chunk: bytes,
    extra_headers: dict[str, str] | None = None,
) -> requests.Response | None:
    """Upload one chunk of a multipart upload with failure resumption.

    Args:
        url: The url to PUT the chunk to.
        upload_chunk: The bytes of the chunk to upload.
        extra_headers: A dictionary of extra headers to send with the request.

    Returns:
        The `requests` library response object.

    Raises:
        retry.TransientError: For failures considered retryable: HTTP
            status 308, 408, 409, 429, 500, 502, 503 or 504, a timeout or
            connection error, or a 400 whose body mentions
            ``RequestTimeout``.
    """
    import requests

    check_httpclient_logger_handler()
    # Bound up front so the final `return` is defined on every path,
    # matching the declared `Response | None` return type.
    response: requests.Response | None = None
    try:
        if env.is_debug(env=self._environ):
            logger.debug("upload_file: %s", url)
        response = self._upload_file_session.put(
            url, data=upload_chunk, headers=extra_headers
        )
        if env.is_debug(env=self._environ):
            logger.debug("upload_file: %s complete", url)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.exception(f"upload_file exception for {url=}")
        response_content = e.response.content if e.response is not None else ""
        status_code = e.response.status_code if e.response is not None else 0
        # S3 reports retryable request timeouts out-of-band
        is_aws_retryable = status_code == 400 and "RequestTimeout" in str(
            response_content
        )
        # Retry errors from cloud storage or local network issues
        if (
            status_code in (308, 408, 409, 429, 500, 502, 503, 504)
            or isinstance(
                e,
                (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
            )
            or is_aws_retryable
        ):
            _e = retry.TransientError(exc=e)
            raise _e.with_traceback(sys.exc_info()[2])
        else:
            get_sentry().reraise(e)
    return response
def upload_file(
    self,
    url: str,
    file: IO[bytes],
    callback: ProgressFn | None = None,
    extra_headers: dict[str, str] | None = None,
) -> requests.Response | None:
    """Upload a file to W&B with failure resumption.

    When the headers contain ``x-ms-blob-type`` and the Azure blob module
    is available, the upload goes through ``upload_file_azure``; otherwise
    the data is sent with a PUT request on the upload session.

    Args:
        url: The url to upload to.
        file: The opened file to upload.
        callback: A callback which is passed the number of bytes uploaded
            since the last time it was called, used to report progress.
        extra_headers: A dictionary of extra headers to send with the
            request. The caller's dict is copied, never modified.

    Returns:
        The `requests` library response object, or None when the upload
        went through the Azure path.

    Raises:
        retry.TransientError: For failures considered retryable: HTTP
            status 308, 408, 409, 429, 500, 502, 503 or 504, a timeout or
            connection error, or (for requests carrying an
            ``x-amz-meta-md5`` header) a 400 whose body mentions
            ``RequestTimeout``.
    """
    import requests

    check_httpclient_logger_handler()
    extra_headers = extra_headers.copy() if extra_headers else {}
    response: requests.Response | None = None
    progress = Progress(file, callback=callback)
    try:
        wants_azure = "x-ms-blob-type" in extra_headers
        if wants_azure and self._azure_blob_module:
            self.upload_file_azure(url, progress, extra_headers)
        else:
            if wants_azure:
                wandb.termwarn(
                    "Azure uploads over 256MB require the azure SDK, install with pip install wandb[azure]",
                    repeat=False,
                )
            if env.is_debug(env=self._environ):
                logger.debug("upload_file: %s", url)
            response = self._upload_file_session.put(
                url, data=progress, headers=extra_headers
            )
            if env.is_debug(env=self._environ):
                logger.debug("upload_file: %s complete", url)
            response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.exception(f"upload_file exception for {url=}")
        failed = e.response
        if failed is not None:
            response_content = failed.content
            status_code = failed.status_code
        else:
            response_content = ""
            status_code = 0

        # S3 reports retryable request timeouts out-of-band
        is_aws_retryable = (
            "x-amz-meta-md5" in extra_headers
            and status_code == 400
            and "RequestTimeout" in str(response_content)
        )
        # We need to rewind the file for the next retry (the file passed in is `seek`'ed to 0)
        progress.rewind()

        # Retry errors from cloud storage or local network issues
        has_retryable_status = status_code in (308, 408, 409, 429, 500, 502, 503, 504)
        is_network_failure = isinstance(
            e,
            (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
        )
        if has_retryable_status or is_network_failure or is_aws_retryable:
            _e = retry.TransientError(exc=e)
            raise _e.with_traceback(sys.exc_info()[2])
        else:
            get_sentry().reraise(e)
    return response
@normalize_exceptions
def register_agent(
    self,
    host: str,
    sweep_id: str | None = None,
    project_name: str | None = None,
    entity: str | None = None,
) -> dict:
    """Register a new agent for a sweep.

    Args:
        host: Hostname of the agent.
        sweep_id: Id of the sweep the agent works on.
        project_name: Project that contains the sweep. Defaults to the
            ``project`` setting.
        entity: Entity that contains the sweep. Defaults to the ``entity``
            setting.

    Returns:
        The ``agent`` object of the response (it contains the agent ``id``).
    """
    mutation = gql(
        """
    mutation CreateAgent(
        $host: String!
        $projectName: String,
        $entityName: String,
        $sweep: String!
    ) {
        createAgent(input: {
            host: $host,
            projectName: $projectName,
            entityName: $entityName,
            sweep: $sweep,
        }) {
            agent {
                id
            }
        }
    }
    """
    )
    entity_name = self.settings("entity") if entity is None else entity
    project = self.settings("project") if project_name is None else project_name

    response = self.gql(
        mutation,
        variable_values={
            "host": host,
            "entityName": entity_name,
            "projectName": project,
            "sweep": sweep_id,
        },
        check_retry_fn=util.no_retry_4xx,
    )
    agent: dict = response["createAgent"]["agent"]
    return agent
def agent_heartbeat(
    self, agent_id: str, metrics: dict, run_states: dict
) -> list[dict[str, Any]]:
    """Report agent state to the server and collect pending commands.

    Args:
        agent_id: Id of the registered agent.
        metrics: System metrics; sent JSON-encoded.
        run_states: Mapping of run id to state; sent JSON-encoded.

    Returns:
        The decoded list of commands to execute. An empty list is returned
        when the request fails for any reason other than a 404.

    Raises:
        ValueError: If ``agent_id`` is None.
        SweepNotFoundError: If the server returns a 404, indicating the
            sweep was likely deleted.
    """
    import requests

    from wandb.sdk.launch.sweeps import SweepNotFoundError

    mutation = gql(
        """
    mutation Heartbeat(
        $id: ID!,
        $metrics: JSONString,
        $runState: JSONString
    ) {
        agentHeartbeat(input: {
            id: $id,
            metrics: $metrics,
            runState: $runState
        }) {
            agent {
                id
            }
            commands
        }
    }
    """
    )

    if agent_id is None:
        raise ValueError("Cannot call heartbeat with an unregistered agent.")

    try:
        response = self.gql(
            mutation,
            variable_values={
                "id": agent_id,
                "metrics": json.dumps(metrics),
                "runState": json.dumps(run_states),
            },
            timeout=60,
        )
    except requests.exceptions.HTTPError as e:
        not_found = e.response is not None and e.response.status_code == 404
        if not_found:
            raise SweepNotFoundError(
                "Sweep not found. The sweep may have been deleted."
            ) from e
        logger.exception("Error communicating with W&B.")
        return []
    except Exception:
        # Best effort: a failed heartbeat yields no commands.
        logger.exception("Error communicating with W&B.")
        return []

    # Only reached when the request succeeded; decoding errors propagate.
    commands: list[dict[str, Any]] = json.loads(
        response["agentHeartbeat"]["commands"]
    )
    return commands
@staticmethod
def _validate_config_and_fill_distribution(config: dict) -> dict:
    """Return a copy of a sweep config with missing distributions filled in.

    A parameter that has both ``min`` and ``max`` but no ``distribution``
    gets ``"int_uniform"`` when both bounds are ints and ``"uniform"`` when
    both are floats. The input is never modified.

    Args:
        config: Sweep configuration; may be a dict subclass.

    Returns:
        A plain ``dict`` deep copy of ``config`` with distributions filled.

    Raises:
        ValueError: If a parameter's bounds are neither both ints nor both
            floats.
    """
    # TODO(dag): deprecate this in favor of jsonschema validation once
    # apiVersion 2 is released and local controller is integrated with
    # wandb/client.

    # Work on a deep copy so a config reused by the caller stays untouched,
    # and cast to a plain dict: a sweepconfig is a dict subclass that does
    # not serialize cleanly to yaml and breaks graphql.
    config = dict(deepcopy(config))

    if "parameters" not in config:
        # still shows an anaconda warning, but doesn't error
        return config

    parameters = config["parameters"]
    for parameter_name in parameters:
        parameter = parameters[parameter_name]
        if (
            "min" not in parameter
            or "max" not in parameter
            or "distribution" in parameter
        ):
            continue

        lower, upper = parameter["min"], parameter["max"]
        if isinstance(lower, int) and isinstance(upper, int):
            parameter["distribution"] = "int_uniform"
        elif isinstance(lower, float) and isinstance(upper, float):
            parameter["distribution"] = "uniform"
        else:
            raise ValueError(
                f"Parameter {parameter_name} is ambiguous, please specify bounds as both floats (for a float_"
                "uniform distribution) or ints (for an int_uniform distribution)."
            )
    return config
@normalize_exceptions
def upsert_sweep(
    self,
    config: dict,
    controller: str | None = None,
    launch_scheduler: str | None = None,
    scheduler: str | None = None,
    obj_id: str | None = None,
    project: str | None = None,
    entity: str | None = None,
    state: str | None = None,
    prior_runs: list[str] | None = None,
    display_name: str | None = None,
    template_variable_values: dict[str, Any] | None = None,
) -> tuple[str, list[str]]:
    """Upsert a sweep object.

    Several variants of the mutation are tried in order, newest first,
    until one succeeds. When the response includes the sweep's project and
    entity, their names are stored via ``self.set_setting``.

    Args:
        config (dict): sweep config (will be converted to yaml)
        controller (str): controller to use
        launch_scheduler (str): launch scheduler to use
        scheduler (str): scheduler to use
        obj_id (str): object id
        project (str): project to use; defaults to the ``project`` setting
        entity (str): entity to use; defaults to the ``entity`` setting
        state (str): state; only sent when truthy
        prior_runs (list): IDs of existing runs to add to the sweep
        display_name (str): display name for the sweep
        template_variable_values (dict): template variable values; sent
            JSON-encoded

    Returns:
        A tuple of the sweep's name and the ``configValidationWarnings``
        from the response (an empty list when the response has none).

    Raises:
        UsageError: Re-raised immediately, without trying further variants.
        Exception: The error from the last variant tried, when every
            variant fails.
    """
    import yaml

    # Sub-selection spliced into the mutation in place of _PROJECT_QUERY_.
    project_query = """
        project {
            id
            name
            entity {
                id
                name
            }
        }
    """
    # Base template; the variants below are derived from it with plain
    # string replacements, so the anchor substrings must stay intact.
    mutation_str = """
    mutation UpsertSweep(
        $id: ID,
        $config: String,
        $description: String,
        $entityName: String,
        $projectName: String,
        $controller: JSONString,
        $scheduler: JSONString,
        $state: String,
        $priorRunsFilters: JSONString,
        $displayName: String,
    ) {
        upsertSweep(input: {
            id: $id,
            config: $config,
            description: $description,
            entityName: $entityName,
            projectName: $projectName,
            controller: $controller,
            scheduler: $scheduler,
            state: $state,
            priorRunsFilters: $priorRunsFilters,
            displayName: $displayName,
        }) {
            sweep {
                name
                _PROJECT_QUERY_
            }
            configValidationWarnings
        }
    }
    """
    # TODO(jhr): we need protocol versioning to know schema is not supported
    # for now we will just try both new and old query
    # mutation 5 adds both launchScheduler and templateVariableValues
    mutation_5 = gql(
        mutation_str.replace(
            "$controller: JSONString,",
            "$controller: JSONString,$launchScheduler: JSONString, $templateVariableValues: JSONString,",
        )
        .replace(
            "controller: $controller,",
            "controller: $controller,launchScheduler: $launchScheduler,templateVariableValues: $templateVariableValues,",
        )
        .replace("_PROJECT_QUERY_", project_query)
    )
    # launchScheduler was introduced in core v0.14.0
    mutation_4 = gql(
        mutation_str.replace(
            "$controller: JSONString,",
            "$controller: JSONString,$launchScheduler: JSONString,",
        )
        .replace(
            "controller: $controller,",
            "controller: $controller,launchScheduler: $launchScheduler",
        )
        .replace("_PROJECT_QUERY_", project_query)
    )
    # mutation 3 maps to backend that can support CLI version of at least 0.10.31
    mutation_3 = gql(mutation_str.replace("_PROJECT_QUERY_", project_query))
    # mutation 2 drops the configValidationWarnings field
    mutation_2 = gql(
        mutation_str.replace("_PROJECT_QUERY_", project_query).replace(
            "configValidationWarnings", ""
        )
    )
    # mutation 1 additionally drops the project sub-selection
    mutation_1 = gql(
        mutation_str.replace("_PROJECT_QUERY_", "").replace(
            "configValidationWarnings", ""
        )
    )
    # TODO(dag): replace this with a query for protocol versioning
    # Ordered newest to oldest; the loop below stops at the first success.
    mutations = [mutation_5, mutation_4, mutation_3, mutation_2, mutation_1]
    config = self._validate_config_and_fill_distribution(config)

    # Silly, but attr-dicts like Easydicts don't serialize correctly to yaml.
    # This sanitizes them with a round trip pass through json to get a regular dict.
    class NonOctalStringDumper(yaml.Dumper):
        """Prevents strings containing non-octal values like "008" and "009" from being converted to numbers in the yaml string saved as the sweep config."""

        def represent_scalar(self, tag, value, style=None):
            # Force single quotes on multi-character strings starting with "0".
            if (
                tag == "tag:yaml.org,2002:str"
                and value.startswith("0")
                and len(value) > 1
            ):
                return super().represent_scalar(tag, value, style="'")
            return super().represent_scalar(tag, value, style)

    config_str = yaml.dump(
        json.loads(json.dumps(config)), Dumper=NonOctalStringDumper
    )
    # Prior runs are expressed as an "$or" filter over run names.
    filters = None
    if prior_runs:
        filters = json.dumps({"$or": [{"name": r} for r in prior_runs]})
    err: Exception | None = None
    for mutation in mutations:
        try:
            variables = {
                "id": obj_id,
                "config": config_str,
                "description": config.get("description"),
                "entityName": entity or self.settings("entity"),
                "projectName": project or self.settings("project"),
                "controller": controller,
                "launchScheduler": launch_scheduler,
                "templateVariableValues": json.dumps(template_variable_values),
                "scheduler": scheduler,
                "priorRunsFilters": filters,
                "displayName": display_name,
            }
            if state:
                variables["state"] = state
            response = self.gql(
                mutation,
                variable_values=variables,
                check_retry_fn=util.no_retry_4xx,
            )
        except UsageError:
            raise
        except Exception as e:
            # graphql schema exception is generic
            # Remember the failure and fall through to the next variant.
            err = e
            continue
        # Success: clear any error left by an earlier variant and stop.
        err = None
        break
    if err:
        raise err
    sweep: dict[str, dict[str, dict]] = response["upsertSweep"]["sweep"]
    # The project sub-selection is absent for the oldest variant.
    project_obj: dict[str, dict] = sweep.get("project", {})
    if project_obj:
        self.set_setting("project", project_obj["name"])
        entity_obj: dict = project_obj.get("entity", {})
        if entity_obj:
            self.set_setting("entity", entity_obj["name"])
    warnings = response["upsertSweep"].get("configValidationWarnings", [])
    return response["upsertSweep"]["sweep"]["name"], warnings
@staticmethod
def file_current(fname: str, md5: B64MD5) -> bool:
    """Tell whether the file at ``fname`` exists and matches ``md5``.

    Args:
        fname: Path of the file to check.
        md5: The expected base64-encoded md5 of the file.

    Returns:
        False when ``fname`` is not an existing file; otherwise whether the
        file's checksum equals ``md5``.
    """
    if not os.path.isfile(fname):
        return False
    return md5_file_b64(fname) == md5
@normalize_exceptions
def pull(
    self, project: str, run: str | None = None, entity: str | None = None
) -> list[requests.Response]:
    """Download the files of a run from W&B.

    Args:
        project: The project to download; parsed together with ``run`` by
            ``self.parse_slug``.
        run: The run to download from.
        entity: The entity to scope this project to.

    Returns:
        The `requests` library response objects of the files that were
        downloaded. Files for which ``download_write_file`` returned no
        response are left out.
    """
    project, run = self.parse_slug(project, run=run)
    file_records = self.download_urls(project, run, entity)

    downloaded = []
    for record in file_records.values():
        _, response = self.download_write_file(record)
        if response:
            downloaded.append(response)
    return downloaded
def get_project(self) -> str:
    """Return the project from the default settings, else from ``self.settings``."""
    default_project = self.default_settings.get("project")
    if default_project:
        return default_project
    project: str = self.settings("project")
    return project
@normalize_exceptions
def push(
    self,
    files: list[str] | dict[str, IO],
    run: str | None = None,
    entity: str | None = None,
    project: str | None = None,
    description: str | None = None,
    force: bool = True,
    progress: TextIO | Literal[False] = False,
) -> list[requests.Response | None]:
    """Uploads multiple files to W&B.

    Args:
        files: The filenames to upload, when dict the values are open files.
        run: The run to upload to. Defaults to ``self.current_run_id``.
        entity: The entity to scope this project to.
        project: The name of the project to upload to. Defaults to the one
            in settings.
        description: Accepted for interface compatibility; not read by
            this method.
        force: Accepted for interface compatibility; not read by this
            method.
        progress: If callable, will be called with (chunk_bytes,
            total_bytes) as argument. If TextIO, renders a progress bar to
            it. ``False`` disables progress reporting.

    Returns:
        A list of `requests.Response` objects, one per uploaded file.
        Files that cannot be opened are reported on stdout and skipped.

    Raises:
        CommError: If no project is configured.
    """
    if project is None:
        project = self.get_project()
    if project is None:
        raise CommError("No project configured.")
    if run is None:
        run = self.current_run_id

    # TODO(adrian): we use a retriable version of self.upload_file() so
    # will never retry self.upload_urls() here. Instead, maybe we should
    # make push itself retriable.
    _, upload_headers, result = self.upload_urls(
        project,
        files,
        run,
        entity,
    )
    # Headers arrive as "Key:Value" strings; split on the first colon only.
    extra_headers = {}
    for upload_header in upload_headers:
        key, val = upload_header.split(":", 1)
        extra_headers[key] = val

    responses = []
    for file_name, file_info in result.items():
        file_url = file_info["uploadUrl"]

        # If the upload URL is relative, fill it in with the base URL,
        # since it's a proxied file store like the on-prem VM.
        if file_url.startswith("/"):
            file_url = f"{self.api_url}{file_url}"

        try:
            # To handle Windows paths
            # TODO: this doesn't handle absolute paths...
            normal_name = os.path.join(*file_name.split("/"))
            open_file = (
                files[file_name]
                if isinstance(files, dict)
                else open(normal_name, "rb")
            )
        except OSError:
            print(f"{file_name} does not exist")  # noqa: T201
            continue

        # Close the handle even when the upload raises, and always upload
        # to the resolved `file_url` so relative URLs also work when no
        # progress reporting is requested.
        try:
            if progress is False:
                responses.append(
                    self.upload_file_retry(
                        file_url, open_file, extra_headers=extra_headers
                    )
                )
            elif callable(progress):
                responses.append(  # type: ignore
                    self.upload_file_retry(
                        file_url, open_file, progress, extra_headers=extra_headers
                    )
                )
            else:
                length = os.fstat(open_file.fileno()).st_size
                with click.progressbar(  # type: ignore
                    file=progress,
                    length=length,
                    label=f"Uploading file: {file_name}",
                    fill_char=click.style("&", fg="green"),
                ) as bar:
                    responses.append(
                        self.upload_file_retry(
                            file_url,
                            open_file,
                            lambda bites, _: bar.update(bites),
                            extra_headers=extra_headers,
                        )
                    )
        finally:
            open_file.close()
    return responses
def link_artifact(
    self,
    client_id: str,
    server_id: str,
    portfolio_name: str,
    entity: str,
    project: str,
    aliases: Sequence[str],
    organization: str,
) -> dict[str, Any]:
    """Link an artifact to a portfolio through the ``linkArtifact`` mutation.

    The artifact is identified by ``server_id`` when it is set, otherwise
    by ``client_id``. For registry projects the entity is replaced with
    the resolved organization entity.

    Args:
        client_id: Client-side id of the artifact.
        server_id: Server-side id of the artifact; takes precedence.
        portfolio_name: Name of the portfolio to link into.
        entity: Entity that owns the portfolio.
        project: Project that owns the portfolio.
        aliases: Aliases to attach to the link.
        organization: Organization name used when resolving the
            organization entity of a registry project.

    Returns:
        The ``linkArtifact`` object of the response.

    Raises:
        ValueError: If the organization entity cannot be resolved; the
            message is also printed with ``wandb.termerror``.
    """
    from wandb.sdk.artifacts._validators import is_artifact_registry_project

    template = """
        mutation LinkArtifact(
            $artifactPortfolioName: String!,
            $entityName: String!,
            $projectName: String!,
            $aliases: [ArtifactAliasInput!],
            ID_TYPE
        ) {
            linkArtifact(input: {
                artifactPortfolioName: $artifactPortfolioName,
                entityName: $entityName,
                projectName: $projectName,
                aliases: $aliases,
                ID_VALUE
            }) {
                versionIndex
            }
        }
    """

    org_entity = ""
    if is_artifact_registry_project(project):
        try:
            org_entity = self._resolve_org_entity_name(
                entity=entity, organization=organization
            )
        except ValueError as e:
            wandb.termerror(str(e))
            raise

    # Pick the id flavour; the placeholders stay untouched when neither id
    # is given.
    if server_id:
        placeholder_values = ("$artifactID: ID", "artifactID: $artifactID")
    elif client_id:
        placeholder_values = ("$clientID: ID", "clientID: $clientID")
    else:
        placeholder_values = None

    if placeholder_values is not None:
        id_type, id_value = placeholder_values
        template = template.replace("ID_TYPE", id_type)
        template = template.replace("ID_VALUE", id_value)

    alias_inputs = [
        {"alias": alias, "artifactCollectionName": portfolio_name}
        for alias in aliases
    ]
    variable_values = {
        "clientID": client_id,
        "artifactID": server_id,
        "artifactPortfolioName": portfolio_name,
        "entityName": org_entity or entity,
        "projectName": project,
        "aliases": alias_inputs,
    }

    response = self.gql(gql(template), variable_values=variable_values)
    link_result: dict[str, Any] = response["linkArtifact"]
    return link_result
def _resolve_org_entity_name(self, entity: str, organization: str = "") -> str:
    """Resolve the organization entity name to use for ``entity``.

    The organizations of ``entity`` are fetched first. When
    ``organization`` is given, it is matched against the fetched
    organizations by ``_match_org_with_fetched_org_entities``. Otherwise
    the entity must belong to a single organization, whose entity name is
    returned.

    Args:
        entity: Entity name to resolve the organization for.
        organization: May be empty, an org's display name, or an org
            entity name.

    Returns:
        The organization entity name.

    Raises:
        ValueError: If ``entity`` is empty, or if no organization was
            given and the entity belongs to more than one organization.
    """
    if not entity:
        raise ValueError("Entity name is required to resolve org entity name.")

    fetched_orgs = self._fetch_orgs_and_org_entities_from_entity(entity)

    if not organization:
        # Without a hint from the caller there is no way to choose between
        # several organizations.
        if len(fetched_orgs) > 1:
            raise ValueError(
                f"Personal entity {entity!r} belongs to multiple organizations "
                "and cannot be used without specifying the organization name. "
                "Please specify the organization in the Registry path or use a team entity in the entity settings."
            )
        sole_org = fetched_orgs[0]
        return sole_org.entity_name

    return _match_org_with_fetched_org_entities(organization, fetched_orgs)
def _fetch_orgs_and_org_entities_from_entity(self, entity: str) -> list[_OrgNames]:
    """Fetch organization entity names and display names for an entity.

    One `FetchOrgEntityFromEntity` query requests both the entity's own
    organization and the organizations of the entity's user; whichever of
    the two is populated decides the result.

    Args:
        entity: Entity name to look up. Can be either a personal or team entity.

    Returns:
        A list of `_OrgNames(entity_name, display_name)`. It holds exactly
        one item when the entity has an organization of its own, otherwise
        one item per user organization that has both a display name and an
        org entity name.

    Raises:
        ValueError: If the entity is missing from the response, has no
            organization, or none of its organizations carry both names.
    """
    query = gql(
        """
        query FetchOrgEntityFromEntity($entityName: String!) {
            entity(name: $entityName) {
                organization {
                    name
                    orgEntity {
                        name
                    }
                }
                user {
                    organizations {
                        name
                        orgEntity {
                            name
                        }
                    }
                }
            }
        }
        """
    )
    response = self.gql(
        query,
        variable_values={
            "entityName": entity,
        },
    )

    # `entity` may come back null or absent (presumably for an unknown entity
    # name -- confirm against the server). Treat that as "no organization"
    # so the caller gets the documented ValueError, not a TypeError.
    entity_data = (response or {}).get("entity") or {}
    team_org = entity_data.get("organization")
    user_data = entity_data.get("user")

    # Team/org entity: a single organization hangs directly off the entity.
    if team_org:
        org_name = team_org.get("name")
        org_entity_name = (team_org.get("orgEntity") or {}).get("name")
        if not org_name or not org_entity_name:
            raise ValueError(
                f"Unable to find an organization under entity {entity!r}."
            )
        return [_OrgNames(entity_name=org_entity_name, display_name=org_name)]

    # Personal entity: the user can belong to multiple organizations.
    if user_data:
        # `organizations` may be null; entries lacking either name are
        # skipped so no `_OrgNames` is built with a missing entity name.
        orgs = user_data.get("organizations") or []
        org_entities_return = [
            _OrgNames(
                entity_name=org["orgEntity"]["name"], display_name=org["name"]
            )
            for org in orgs
            if org.get("name") and (org.get("orgEntity") or {}).get("name")
        ]
        if not org_entities_return:
            raise ValueError(
                f"Unable to resolve an organization associated with personal entity: {entity!r}. "
                "This could be because its a personal entity that doesn't belong to any organizations. "
                "Please specify the organization in the Registry path or use a team entity in the entity settings."
            )
        return org_entities_return

    raise ValueError(f"Unable to find an organization under entity {entity!r}.")
def _construct_use_artifact_query(
    self,
    artifact_id: str,
    entity_name: str | None = None,
    project_name: str | None = None,
    run_name: str | None = None,
    use_as: str | None = None,
    artifact_entity_name: str | None = None,
    artifact_project_name: str | None = None,
) -> tuple[Document, dict[str, Any]]:
    """Build the `UseArtifact` mutation and its variable values.

    Entity, project and run default to `self.settings("entity")`,
    `self.settings("project")` and `self.current_run_id` when not given.
    `usedAs` is only declared in the mutation when `use_as` is truthy, and
    the artifact entity/project inputs are only declared when the server
    reports `USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION` support.

    Args:
        artifact_id: ID of the artifact being used.
        entity_name: Entity of the run; falls back to settings.
        project_name: Project of the run; falls back to settings.
        run_name: Run name; falls back to the current run id.
        use_as: Optional `usedAs` value.
        artifact_entity_name: Optional entity of the artifact.
        artifact_project_name: Optional project of the artifact.

    Returns:
        A `(query, variable_values)` pair for `self.gql`.
    """
    # GraphQL variable name -> declared type, in declaration order.
    declared: dict[str, str] = {
        "entityName": "String!",
        "projectName": "String!",
        "runName": "String!",
        "artifactID": "ID!",
    }
    if use_as:
        declared["usedAs"] = "String"

    entity_name = entity_name or self.settings("entity")
    project_name = project_name or self.settings("project")
    run_name = run_name or self.current_run_id

    # `usedAs` is always sent as a value, even when it is not declared above.
    variable_values: dict[str, Any] = {
        "entityName": entity_name,
        "projectName": project_name,
        "runName": run_name,
        "artifactID": artifact_id,
        "usedAs": use_as,
    }

    if self._server_supports(
        ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION
    ):
        declared["artifactEntityName"] = "String"
        declared["artifactProjectName"] = "String"
        variable_values["artifactEntityName"] = artifact_entity_name
        variable_values["artifactProjectName"] = artifact_project_name

    vars_str = ", ".join(
        f"${var_name}: {gql_type}" for var_name, gql_type in declared.items()
    )
    args_str = ", ".join(f"{var_name}: ${var_name}" for var_name in declared)

    query = gql(
        f"""
        mutation UseArtifact({vars_str}) {{
            useArtifact(input: {{{args_str}}}) {{
                artifact {{
                    id
                    digest
                    description
                    state
                    createdAt
                    metadata
                }}
            }}
        }}
        """
    )
    return query, variable_values
def use_artifact(
    self,
    artifact_id: str,
    entity_name: str | None = None,
    project_name: str | None = None,
    run_name: str | None = None,
    artifact_entity_name: str | None = None,
    artifact_project_name: str | None = None,
    use_as: str | None = None,
) -> dict[str, Any] | None:
    """Run the `useArtifact` mutation and return the artifact it reports.

    The query is built by `_construct_use_artifact_query`. Note that this
    method takes `use_as` last while the builder takes it fifth, so the
    arguments are forwarded in the builder's order.

    Args:
        artifact_id: ID of the artifact being used.
        entity_name: Entity of the run.
        project_name: Project of the run.
        run_name: Name of the run.
        artifact_entity_name: Entity of the artifact.
        artifact_project_name: Project of the artifact.
        use_as: Optional `usedAs` value.

    Returns:
        The `artifact` mapping from the response, or `None` when the
        response carries no (or an empty) artifact.
    """
    query, variable_values = self._construct_use_artifact_query(
        artifact_id,
        entity_name,
        project_name,
        run_name,
        use_as,
        artifact_entity_name,
        artifact_project_name,
    )
    response = self.gql(query, variable_values)

    artifact: dict[str, Any] | None = response["useArtifact"]["artifact"]
    if not artifact:
        return None
    return artifact
def create_artifact_type(
    self,
    artifact_type_name: str,
    entity_name: str | None = None,
    project_name: str | None = None,
    description: str | None = None,
) -> str | None:
    """Run the `CreateArtifactType` mutation and return the type's id.

    Args:
        artifact_type_name: Name of the artifact type.
        entity_name: Entity; falls back to `self.settings("entity")`.
        project_name: Project; falls back to `self.settings("project")`.
        description: Optional description sent with the type.

    Returns:
        The `id` of the `artifactType` in the response.
    """
    mutation = gql(
        """
        mutation CreateArtifactType(
            $entityName: String!,
            $projectName: String!,
            $artifactTypeName: String!,
            $description: String
        ) {
            createArtifactType(input: {
                entityName: $entityName,
                projectName: $projectName,
                name: $artifactTypeName,
                description: $description
            }) {
                artifactType {
                    id
                }
            }
        }
        """
    )
    variables = {
        "entityName": entity_name or self.settings("entity"),
        "projectName": project_name or self.settings("project"),
        "artifactTypeName": artifact_type_name,
        "description": description,
    }
    response = self.gql(mutation, variable_values=variables)
    type_id: str | None = response["createArtifactType"]["artifactType"]["id"]
    return type_id
- def _get_create_artifact_mutation(
- self,
- history_step: int | None,
- distributed_id: str | None,
- ) -> str:
- types = ""
- values = ""
- if history_step not in [0, None]:
- types += "$historyStep: Int64!,"
- values += "historyStep: $historyStep,"
- if distributed_id:
- types += "$distributedID: String,"
- values += "distributedID: $distributedID,"
- query_template = """
- mutation CreateArtifact(
- $artifactTypeName: String!,
- $artifactCollectionNames: [String!],
- $entityName: String!,
- $projectName: String!,
- $runName: String,
- $description: String,
- $digest: String!,
- $aliases: [ArtifactAliasInput!],
- $metadata: JSONString,
- $clientID: ID,
- $sequenceClientID: ID,
- $ttlDurationSeconds: Int64,
- $tags: [TagInput!],
- _CREATE_ARTIFACT_ADDITIONAL_TYPE_
- ) {
- createArtifact(input: {
- artifactTypeName: $artifactTypeName,
- artifactCollectionNames: $artifactCollectionNames,
- entityName: $entityName,
- projectName: $projectName,
- runName: $runName,
- description: $description,
- digest: $digest,
- digestAlgorithm: MANIFEST_MD5,
- aliases: $aliases,
- metadata: $metadata,
- clientID: $clientID,
- sequenceClientID: $sequenceClientID,
- enableDigestDeduplication: true,
- ttlDurationSeconds: $ttlDurationSeconds,
- tags: $tags,
- _CREATE_ARTIFACT_ADDITIONAL_VALUE_
- }) {
- artifact {
- id
- state
- artifactSequence {
- id
- latestArtifact {
- id
- versionIndex
- }
- }
- }
- }
- }
- """
- return query_template.replace(
- "_CREATE_ARTIFACT_ADDITIONAL_TYPE_", types
- ).replace("_CREATE_ARTIFACT_ADDITIONAL_VALUE_", values)
def create_artifact(
    self,
    artifact_type_name: str,
    artifact_collection_name: str,
    digest: str,
    client_id: str | None = None,
    sequence_client_id: str | None = None,
    entity_name: str | None = None,
    project_name: str | None = None,
    run_name: str | None = None,
    description: str | None = None,
    metadata: dict | None = None,
    ttl_duration_seconds: int | None = None,
    aliases: list[dict[str, str]] | None = None,
    tags: list[dict[str, str]] | None = None,
    distributed_id: str | None = None,
    is_user_created: bool | None = False,
    history_step: int | None = None,
) -> tuple[dict, dict]:
    """Run the `CreateArtifact` mutation.

    Entity and project default to the values in settings. The run name
    defaults to `self.current_run_id` unless `is_user_created` is truthy.
    Truthy `metadata` is made JSON-safe and serialised; falsy metadata is
    sent as `None`.

    Args:
        artifact_type_name: Name of the artifact type.
        artifact_collection_name: Collection name, sent as a one-item list.
        digest: Artifact digest.
        client_id: Optional client id.
        sequence_client_id: Optional sequence client id.
        entity_name: Entity; falls back to settings.
        project_name: Project; falls back to settings.
        run_name: Run name; see above for the default.
        description: Optional description.
        metadata: Optional metadata mapping.
        ttl_duration_seconds: Optional TTL in seconds.
        aliases: Optional alias inputs.
        tags: Optional tag inputs.
        distributed_id: Optional distributed id.
        is_user_created: When truthy, the run name is not defaulted.
        history_step: Optional history step.

    Returns:
        A pair of the created `artifact` mapping and the
        `latestArtifact` entry of its `artifactSequence`.
    """
    query_template = self._get_create_artifact_mutation(
        history_step,
        distributed_id,
    )

    entity_name = entity_name or self.settings("entity")
    project_name = project_name or self.settings("project")
    if not is_user_created:
        run_name = run_name or self.current_run_id

    mutation = gql(query_template)

    metadata_json = (
        json.dumps(util.make_safe_for_json(metadata)) if metadata else None
    )
    variables = {
        "entityName": entity_name,
        "projectName": project_name,
        "runName": run_name,
        "artifactTypeName": artifact_type_name,
        "artifactCollectionNames": [artifact_collection_name],
        "clientID": client_id,
        "sequenceClientID": sequence_client_id,
        "digest": digest,
        "description": description,
        "aliases": list(aliases or []),
        "tags": list(tags or []),
        "metadata": metadata_json,
        "ttlDurationSeconds": ttl_duration_seconds,
        "distributedID": distributed_id,
        "historyStep": history_step,
    }
    response = self.gql(mutation, variable_values=variables)

    created = response["createArtifact"]["artifact"]
    latest = created["artifactSequence"].get("latestArtifact")
    return created, latest
def commit_artifact(self, artifact_id: str) -> _Response:
    """Run the `CommitArtifact` mutation for `artifact_id`.

    Args:
        artifact_id: ID of the artifact to commit.

    Returns:
        The raw response of the mutation; the query requests the
        artifact's `id` and `digest`.
    """
    mutation = gql(
        """
        mutation CommitArtifact(
            $artifactID: ID!,
        ) {
            commitArtifact(input: {
                artifactID: $artifactID,
            }) {
                artifact {
                    id
                    digest
                }
            }
        }
        """
    )
    variables = {"artifactID": artifact_id}
    # This call passes an explicit `timeout=60` to `self.gql`.
    result: _Response = self.gql(mutation, variable_values=variables, timeout=60)
    return result
def complete_multipart_upload_artifact(
    self,
    artifact_id: str,
    storage_path: str,
    completed_parts: list[dict[str, Any]],
    upload_id: str | None,
    complete_multipart_action: str = "Complete",
) -> str | None:
    """Run the `CompleteMultipartUploadArtifact` mutation.

    Args:
        artifact_id: ID of the artifact the upload belongs to.
        storage_path: Storage path of the uploaded file.
        completed_parts: Part descriptors sent as `completedParts`.
        upload_id: Upload id sent as `uploadID`.
        complete_multipart_action: Action sent as `completeMultipartAction`.

    Returns:
        The `digest` reported by the mutation.
    """
    mutation = gql(
        """
        mutation CompleteMultipartUploadArtifact(
            $completeMultipartAction: CompleteMultipartAction!,
            $completedParts: [UploadPartsInput!]!,
            $artifactID: ID!
            $storagePath: String!
            $uploadID: String!
        ) {
            completeMultipartUploadArtifact(
                input: {
                    completeMultipartAction: $completeMultipartAction,
                    completedParts: $completedParts,
                    artifactID: $artifactID,
                    storagePath: $storagePath
                    uploadID: $uploadID
                }
            ) {
                digest
            }
        }
        """
    )
    variables = {
        "completeMultipartAction": complete_multipart_action,
        "artifactID": artifact_id,
        "storagePath": storage_path,
        "completedParts": completed_parts,
        "uploadID": upload_id,
    }
    response = self.gql(mutation, variable_values=variables)
    upload_digest: str | None = response["completeMultipartUploadArtifact"]["digest"]
    return upload_digest
def create_artifact_manifest(
    self,
    name: str,
    digest: str,
    artifact_id: str | None,
    base_artifact_id: str | None = None,
    entity: str | None = None,
    project: str | None = None,
    run: str | None = None,
    include_upload: bool = True,
    type: str = "FULL",
) -> tuple[str, dict[str, Any]]:
    """Run the `CreateArtifactManifest` mutation.

    The `$type` variable and `type` input are only written into the
    mutation when `type` differs from `"FULL"`; the value itself is always
    included in the variables that are sent.

    Args:
        name: Manifest name.
        digest: Manifest digest.
        artifact_id: ID of the artifact the manifest belongs to.
        base_artifact_id: Optional base artifact id.
        entity: Entity; falls back to `self.settings("entity")`.
        project: Project; falls back to `self.settings("project")`.
        run: Run name; falls back to `self.current_run_id`.
        include_upload: Controls the `@include` directive on the upload
            URL and upload headers fields.
        type: Manifest type.

    Returns:
        A pair of the manifest `id` and its `file` mapping.
    """
    non_default_type = type != "FULL"
    type_declaration = (
        "$type: ArtifactManifestType = FULL" if non_default_type else ""
    )
    type_argument = "type: $type" if non_default_type else ""

    mutation = gql(
        f"""
        mutation CreateArtifactManifest(
            $name: String!,
            $digest: String!,
            $artifactID: ID!,
            $baseArtifactID: ID,
            $entityName: String!,
            $projectName: String!,
            $runName: String!,
            $includeUpload: Boolean!,
            {type_declaration}
        ) {{
            createArtifactManifest(input: {{
                name: $name,
                digest: $digest,
                artifactID: $artifactID,
                baseArtifactID: $baseArtifactID,
                entityName: $entityName,
                projectName: $projectName,
                runName: $runName,
                {type_argument}
            }}) {{
                artifactManifest {{
                    id
                    file {{
                        id
                        name
                        displayName
                        uploadUrl @include(if: $includeUpload)
                        uploadHeaders @include(if: $includeUpload)
                    }}
                }}
            }}
        }}
        """
    )

    variables = {
        "name": name,
        "digest": digest,
        "artifactID": artifact_id,
        "baseArtifactID": base_artifact_id,
        "entityName": entity or self.settings("entity"),
        "projectName": project or self.settings("project"),
        "runName": run or self.current_run_id,
        "includeUpload": include_upload,
        "type": type,
    }
    response = self.gql(mutation, variable_values=variables)

    manifest = response["createArtifactManifest"]["artifactManifest"]
    return manifest["id"], manifest["file"]
def update_artifact_manifest(
    self,
    artifact_manifest_id: str,
    base_artifact_id: str | None = None,
    digest: str | None = None,
    include_upload: bool | None = True,
) -> tuple[str, dict[str, Any]]:
    """Run the `UpdateArtifactManifest` mutation.

    Args:
        artifact_manifest_id: ID of the manifest to update.
        base_artifact_id: Optional base artifact id.
        digest: Optional new digest.
        include_upload: Controls the `@include` directive on the upload
            URL and upload headers fields.

    Returns:
        A pair of the manifest `id` and its `file` mapping.
    """
    mutation = gql(
        """
        mutation UpdateArtifactManifest(
            $artifactManifestID: ID!,
            $digest: String,
            $baseArtifactID: ID,
            $includeUpload: Boolean!,
        ) {
            updateArtifactManifest(input: {
                artifactManifestID: $artifactManifestID,
                digest: $digest,
                baseArtifactID: $baseArtifactID,
            }) {
                artifactManifest {
                    id
                    file {
                        id
                        name
                        displayName
                        uploadUrl @include(if: $includeUpload)
                        uploadHeaders @include(if: $includeUpload)
                    }
                }
            }
        }
        """
    )
    variables = {
        "artifactManifestID": artifact_manifest_id,
        "digest": digest,
        "baseArtifactID": base_artifact_id,
        "includeUpload": include_upload,
    }
    response = self.gql(mutation, variable_values=variables)

    manifest = response["updateArtifactManifest"]["artifactManifest"]
    return manifest["id"], manifest["file"]
def update_artifact_metadata(
    self, artifact_id: str, metadata: dict[str, Any]
) -> dict[str, Any]:
    """Set the metadata of the given artifact version.

    Args:
        artifact_id: ID of the artifact version to update.
        metadata: New metadata; serialised with `json.dumps` before sending.

    Returns:
        The `artifact` mapping from the response (the query requests `id`).
    """
    mutation = gql(
        """
        mutation UpdateArtifact(
            $artifactID: ID!,
            $metadata: JSONString,
        ) {
            updateArtifact(input: {
                artifactID: $artifactID,
                metadata: $metadata,
            }) {
                artifact {
                    id
                }
            }
        }
        """
    )
    variables = {
        "artifactID": artifact_id,
        "metadata": json.dumps(metadata),
    }
    response = self.gql(mutation, variable_values=variables)
    updated = response["updateArtifact"]
    return updated["artifact"]
- def _resolve_client_id(
- self,
- client_id: str,
- ) -> str | None:
- if client_id in self._client_id_mapping:
- return self._client_id_mapping[client_id]
- query = gql(
- """
- query ClientIDMapping($clientID: ID!) {
- clientIDMapping(clientID: $clientID) {
- serverID
- }
- }
- """
- )
- response = self.gql(
- query,
- variable_values={
- "clientID": client_id,
- },
- )
- server_id = None
- if response is not None:
- client_id_mapping = response.get("clientIDMapping")
- if client_id_mapping is not None:
- server_id = client_id_mapping.get("serverID")
- if server_id is not None:
- self._client_id_mapping[client_id] = server_id
- return server_id
@normalize_exceptions
def create_artifact_files(
    self, artifact_files: Iterable[CreateArtifactFileSpecInput]
) -> Mapping[str, CreateArtifactFilesResponseFile]:
    """Run the `CreateArtifactFiles` mutation for a batch of file specs.

    The storage layout sent is `"V2"` unless `env.get_use_v1_artifacts()`
    is truthy, in which case it is `"V1"`.

    Args:
        artifact_files: File specs to create; consumed into a list.

    Returns:
        A mapping from each returned node's `displayName` to the node. If
        two nodes share a display name, the later one wins.
    """
    query_template = """
        mutation CreateArtifactFiles(
            $storageLayout: ArtifactStorageLayout!
            $artifactFiles: [CreateArtifactFileSpecInput!]!
        ) {
            createArtifactFiles(input: {
                artifactFiles: $artifactFiles,
                storageLayout: $storageLayout,
            }) {
                files {
                    edges {
                        node {
                            id
                            name
                            displayName
                            uploadUrl
                            uploadHeaders
                            storagePath
                            uploadMultipartUrls {
                                uploadID
                                uploadUrlParts {
                                    partNumber
                                    uploadUrl
                                }
                            }
                            artifact {
                                id
                            }
                        }
                    }
                }
            }
        }
    """

    # TODO: we should use constants here from interface/artifacts.py
    # but probably don't want the dependency. We're going to remove
    # this setting in a future release, so I'm just hard-coding the strings.
    storage_layout = "V1" if env.get_use_v1_artifacts() else "V2"

    mutation = gql(query_template)
    variables = {
        "storageLayout": storage_layout,
        "artifactFiles": list(artifact_files),
    }
    response = self.gql(mutation, variable_values=variables)

    nodes = (edge["node"] for edge in response["createArtifactFiles"]["files"]["edges"])
    return {node["displayName"]: node for node in nodes}
@normalize_exceptions
def notify_scriptable_run_alert(
    self,
    title: str,
    text: str,
    level: str | None = None,
    wait_duration: Number | None = None,
) -> bool:
    """Run the `NotifyScriptableRunAlert` mutation for the current run.

    Entity and project come from settings and the run name from
    `self.current_run_id`.

    Args:
        title: Alert title.
        text: Alert text.
        level: Sent as the `severity` variable; the mutation declares
            `INFO` as that variable's default.
        wait_duration: Sent as the `waitDuration` variable.

    Returns:
        The `success` flag from the response.
    """
    mutation = gql(
        """
        mutation NotifyScriptableRunAlert(
            $entityName: String!,
            $projectName: String!,
            $runName: String!,
            $title: String!,
            $text: String!,
            $severity: AlertSeverity = INFO,
            $waitDuration: Duration
        ) {
            notifyScriptableRunAlert(input: {
                entityName: $entityName,
                projectName: $projectName,
                runName: $runName,
                title: $title,
                text: $text,
                severity: $severity,
                waitDuration: $waitDuration
            }) {
                success
            }
        }
        """
    )
    variables = {
        "entityName": self.settings("entity"),
        "projectName": self.settings("project"),
        "runName": self.current_run_id,
        "title": title,
        "text": text,
        "severity": level,
        "waitDuration": wait_duration,
    }
    response = self.gql(mutation, variable_values=variables)
    delivered: bool = response["notifyScriptableRunAlert"]["success"]
    return delivered
def get_sweep_state(
    self, sweep: str, entity: str | None = None, project: str | None = None
) -> SweepState:
    """Query the state of a sweep.

    Args:
        sweep: Sweep name.
        entity: Entity; falls back to `self.settings("entity")`.
        project: Project; falls back to `self.settings("project")`.

    Returns:
        The `state` field of the sweep in the response.
    """
    query = gql(
        """
        query GetSweepState($entity: String, $project: String, $sweep: String!) {
            project(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    state
                }
            }
        }
        """
    )
    variables = {
        "sweep": sweep,
        "entity": entity or self.settings("entity"),
        "project": project or self.settings("project"),
    }
    response = self.gql(query, variable_values=variables)
    sweep_info = response["project"]["sweep"]
    return sweep_info["state"]
def set_sweep_state(
    self,
    sweep: str,
    state: SweepState,
    entity: str | None = None,
    project: str | None = None,
) -> None:
    """Move a sweep to `state` via the `UpsertSweep` mutation.

    The sweep's current state is looked up with `self.sweep` and the
    transition is checked before anything is sent.

    Args:
        sweep: Sweep name.
        state: Target state; one of "RUNNING", "PAUSED", "CANCELED",
            "FINISHED" (checked with an `assert`).
        entity: Entity; falls back to `self.settings("entity")`.
        project: Project; falls back to `self.settings("project")`.

    Raises:
        Exception: If pausing is requested for a sweep that is neither
            paused nor running, or if a target other than "RUNNING" is
            requested for a sweep that is not running, paused or pending.
    """
    assert state in ("RUNNING", "PAUSED", "CANCELED", "FINISHED")

    sweep_info = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
    curr_state = sweep_info["state"].upper()

    can_pause = curr_state in ("PAUSED", "RUNNING")
    still_open = curr_state in ("RUNNING", "PAUSED", "PENDING")
    if state == "PAUSED" and not can_pause:
        raise Exception(f"Cannot pause {curr_state.lower()} sweep.")
    if state != "RUNNING" and not still_open:
        raise Exception(f"Sweep already {curr_state.lower()}.")

    sweep_id = sweep_info["id"]
    mutation = gql(
        """
        mutation UpsertSweep(
            $id: ID,
            $state: String,
            $entityName: String,
            $projectName: String
        ) {
            upsertSweep(input: {
                id: $id,
                state: $state,
                entityName: $entityName,
                projectName: $projectName
            }){
                sweep {
                    name
                }
            }
        }
        """
    )
    variables = {
        "id": sweep_id,
        "state": state,
        "entityName": entity or self.settings("entity"),
        "projectName": project or self.settings("project"),
    }
    self.gql(mutation, variable_values=variables)
def stop_sweep(
    self,
    sweep: str,
    entity: str | None = None,
    project: str | None = None,
) -> None:
    """Finish the sweep: no new runs start, running ones are left to finish.

    Delegates to `set_sweep_state` with the "FINISHED" state.

    Args:
        sweep: Sweep name.
        entity: Optional entity, forwarded unchanged.
        project: Optional project, forwarded unchanged.
    """
    target_state = "FINISHED"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
def cancel_sweep(
    self,
    sweep: str,
    entity: str | None = None,
    project: str | None = None,
) -> None:
    """Cancel the sweep: running runs are killed and no new runs start.

    Delegates to `set_sweep_state` with the "CANCELED" state.

    Args:
        sweep: Sweep name.
        entity: Optional entity, forwarded unchanged.
        project: Optional project, forwarded unchanged.
    """
    target_state = "CANCELED"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
def pause_sweep(
    self,
    sweep: str,
    entity: str | None = None,
    project: str | None = None,
) -> None:
    """Pause the sweep so that new runs temporarily stop being started.

    Delegates to `set_sweep_state` with the "PAUSED" state.

    Args:
        sweep: Sweep name.
        entity: Optional entity, forwarded unchanged.
        project: Optional project, forwarded unchanged.
    """
    target_state = "PAUSED"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
def resume_sweep(
    self,
    sweep: str,
    entity: str | None = None,
    project: str | None = None,
) -> None:
    """Resume the sweep so that new runs are started again.

    Delegates to `set_sweep_state` with the "RUNNING" state.

    Args:
        sweep: Sweep name.
        entity: Optional entity, forwarded unchanged.
        project: Optional project, forwarded unchanged.
    """
    target_state = "RUNNING"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
def _status_request(self, url: str, length: int) -> requests.Response:
    """Ask google how much we've uploaded.

    Sends an empty-bodied `PUT` to `url` with a `Content-Range` header of
    the form `bytes */<length>`.

    Args:
        url: URL to send the status request to.
        length: Total length placed after `*/` in the `Content-Range` header.

    Returns:
        The `requests.Response` of the `PUT`.
    """
    import requests

    check_httpclient_logger_handler()
    status_headers = {
        "Content-Length": "0",
        "Content-Range": f"bytes */{length}",
    }
    return requests.put(url=url, headers=status_headers)
- def _flatten_edges(self, response: _Response) -> list[dict]:
- """Return an array from the nested graphql relay structure."""
- return [node["node"] for node in response["edges"]]
@normalize_exceptions
def stop_run(
    self,
    run_id: str,
) -> bool:
    """Run the `stopRun` mutation for a run.

    Args:
        run_id: ID of the run, sent as the `id` variable.

    Returns:
        The `success` value of the response (`None` if the field is absent).
    """
    mutation = gql(
        """
        mutation stopRun($id: ID!) {
            stopRun(input: {
                id: $id
            }) {
                clientMutationId
                success
            }
        }
        """
    )
    response = self.gql(mutation, variable_values={"id": run_id})
    stopped: bool = response["stopRun"].get("success")
    return stopped
@normalize_exceptions
def create_custom_chart(
    self,
    entity: str,
    name: str,
    display_name: str,
    spec_type: str,
    access: str,
    spec: str | Mapping[str, Any],
) -> dict[str, Any] | None:
    """Run the `CreateCustomChart` mutation.

    Args:
        entity: Entity that owns the chart.
        name: Chart name.
        display_name: Chart display name.
        spec_type: Sent as the `type` variable.
        access: Sent as the `access` variable.
        spec: Chart spec; anything that is not already a string is
            serialised with `json.dumps`.

    Returns:
        The `createCustomChart` entry of the response.
    """
    spec_json = spec if isinstance(spec, str) else json.dumps(spec)

    mutation = gql(
        """
        mutation CreateCustomChart(
            $entity: String!
            $name: String!
            $displayName: String!
            $type: String!
            $access: String!
            $spec: JSONString!
        ) {
            createCustomChart(
                input: {
                    entity: $entity
                    name: $name
                    displayName: $displayName
                    type: $type
                    access: $access
                    spec: $spec
                }
            ) {
                chart { id }
            }
        }
        """
    )
    variable_values = {
        "entity": entity,
        "name": name,
        "displayName": display_name,
        "type": spec_type,
        "access": access,
        "spec": spec_json,
    }
    response = self.gql(mutation, variable_values)
    chart_result: dict[str, Any] | None = response["createCustomChart"]
    return chart_result
|