internal_api.py 136 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109
  1. from __future__ import annotations
  2. import base64
  3. import datetime
  4. import functools
  5. import http.client
  6. import json
  7. import logging
  8. import os
  9. import re
  10. import socket
  11. import sys
  12. import threading
  13. from collections.abc import Iterable, Mapping, MutableMapping, Sequence
  14. from copy import deepcopy
  15. from pathlib import Path
  16. from typing import (
  17. IO,
  18. TYPE_CHECKING,
  19. Any,
  20. Callable,
  21. Literal,
  22. NamedTuple,
  23. TextIO,
  24. Union,
  25. overload,
  26. )
  27. import click
  28. from wandb_gql import Client, gql
  29. from wandb_gql.client import RetryError
  30. from wandb_graphql.language.ast import Document
  31. import wandb
  32. from wandb import env, util
  33. from wandb.analytics import get_sentry
  34. from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messages
  35. from wandb.errors import AuthenticationError, CommError, UsageError
  36. from wandb.integration.sagemaker import parse_sm_secrets
  37. from wandb.proto.wandb_internal_pb2 import ServerFeature
  38. from wandb.sdk import wandb_setup
  39. from wandb.sdk.internal import settings_static
  40. from wandb.sdk.internal._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
  41. from wandb.sdk.lib.gql_request import GraphQLSession
  42. from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
  43. from ..lib import retry, wbauth
  44. from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
  45. from . import context
  46. from .progress import Progress
  47. logger = logging.getLogger(__name__)
  48. LAUNCH_DEFAULT_PROJECT = "model-registry"
  49. if TYPE_CHECKING:
  50. from typing import Literal, TypedDict
  51. import requests
  52. from .progress import ProgressFn
  53. class CreateArtifactFileSpecInput(TypedDict, total=False):
  54. """Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql."""
  55. artifactID: str
  56. name: str
  57. md5: str
  58. mimetype: str | None
  59. artifactManifestID: str | None
  60. uploadPartsInput: list[dict[str, object]] | None
  61. class CreateArtifactFilesResponseFile(TypedDict):
  62. id: str
  63. name: str
  64. displayName: str
  65. uploadUrl: str | None
  66. uploadHeaders: Sequence[str]
  67. uploadMultipartUrls: UploadPartsResponse
  68. storagePath: str
  69. artifact: CreateArtifactFilesResponseFileNode
  70. class CreateArtifactFilesResponseFileNode(TypedDict):
  71. id: str
  72. class UploadPartsResponse(TypedDict):
  73. uploadUrlParts: list[UploadUrlParts]
  74. uploadID: str
  75. class UploadUrlParts(TypedDict):
  76. partNumber: int
  77. uploadUrl: str
  78. class CompleteMultipartUploadArtifactInput(TypedDict):
  79. """Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql."""
  80. completeMultipartAction: str
  81. completedParts: dict[int, str]
  82. artifactID: str
  83. storagePath: str
  84. uploadID: str
  85. md5: str
  86. class CompleteMultipartUploadArtifactResponse(TypedDict):
  87. digest: str
  88. class DefaultSettings(TypedDict, total=False):
  89. section: str
  90. git_remote: str
  91. ignore_globs: list[str]
  92. base_url: str
  93. root_dir: str | None
  94. api_key: str | None
  95. entity: str | None
  96. organization: str | None
  97. project: str | None
  98. _extra_http_headers: Mapping[str, str] | None
  99. _proxies: Mapping[str, str] | None
  100. _Response = MutableMapping
  101. SweepState = Literal["RUNNING", "PAUSED", "CANCELED", "FINISHED"]
  102. Number = Union[int, float]
  103. httpclient_logger = logging.getLogger("http.client")
  104. if os.environ.get("WANDB_DEBUG"):
  105. httpclient_logger.setLevel(logging.DEBUG)
  106. def check_httpclient_logger_handler() -> None:
  107. # Only enable http.client logging if WANDB_DEBUG is set
  108. if not os.environ.get("WANDB_DEBUG"):
  109. return
  110. if httpclient_logger.handlers:
  111. return
  112. # Enable HTTPConnection debug logging to the logging framework
  113. level = logging.DEBUG
  114. def httpclient_log(*args: Any) -> None:
  115. httpclient_logger.log(level, " ".join(args))
  116. # mask the print() built-in in the http.client module to use logging instead
  117. http.client.print = httpclient_log # type: ignore[attr-defined]
  118. # enable debugging
  119. http.client.HTTPConnection.debuglevel = 1
  120. root_logger = logging.getLogger("wandb")
  121. if root_logger.handlers:
  122. httpclient_logger.addHandler(root_logger.handlers[0])
  123. class _ThreadLocalData(threading.local):
  124. context: context.Context | None
  125. def __init__(self) -> None:
  126. self.context = None
  127. class _OrgNames(NamedTuple):
  128. entity_name: str
  129. display_name: str
  130. def _match_org_with_fetched_org_entities(
  131. organization: str, orgs: Sequence[_OrgNames]
  132. ) -> str:
  133. """Match the organization provided in the path with the org entity or org name of the input entity.
  134. Args:
  135. organization: The organization name to match
  136. orgs: list of tuples containing (org_entity_name, org_display_name)
  137. Returns:
  138. str: The matched org entity name
  139. Raises:
  140. ValueError: If no matching organization is found or if multiple orgs exist without a match
  141. """
  142. for org_names in orgs:
  143. if organization in org_names:
  144. return org_names.entity_name
  145. if len(orgs) == 1:
  146. raise ValueError(
  147. f"Expecting the organization name or entity name to match {orgs[0].display_name!r} "
  148. f"and cannot be linked/fetched with {organization!r}. "
  149. "Please update the target path with the correct organization name."
  150. )
  151. raise ValueError(
  152. "Personal entity belongs to multiple organizations "
  153. f"and cannot be linked/fetched with {organization!r}. "
  154. "Please update the target path with the correct organization name "
  155. "or use a team entity in the entity settings."
  156. )
  157. class Api:
  158. """W&B Internal Api wrapper.
  159. Note:
  160. Settings are automatically overridden by looking for
  161. a `wandb/settings` file in the current working directory or its parent
  162. directory. If none can be found, we look in the current user's home
  163. directory.
  164. Args:
  165. default_settings(dict, optional): If you aren't using a settings
  166. file, or you wish to override the section to use in the settings file
  167. Override the settings here.
  168. """
  169. HTTP_TIMEOUT = env.get_http_timeout(20)
  170. FILE_PUSHER_TIMEOUT = env.get_file_pusher_timeout()
  171. _global_context: context.Context
  172. _local_data: _ThreadLocalData
  173. def __init__(
  174. self,
  175. default_settings: (
  176. wandb.Settings #
  177. | settings_static.SettingsStatic
  178. | DefaultSettings
  179. | None
  180. ) = None,
  181. load_settings: bool = True,
  182. retry_timedelta: datetime.timedelta | None = None,
  183. environ: MutableMapping[str, str] = os.environ,
  184. retry_callback: Callable[[int, str], Any] | None = None,
  185. api_key: str | None = None,
  186. ) -> None:
  187. import requests
  188. self._environ = environ
  189. self._global_context = context.Context()
  190. self._local_data = _ThreadLocalData()
  191. default_overrides: dict[str, Any] = (
  192. dict(default_settings) if default_settings else {}
  193. )
  194. self.default_settings: DefaultSettings = {
  195. "section": default_overrides.get("section", "default"),
  196. "git_remote": default_overrides.get("git_remote", "origin"),
  197. "ignore_globs": default_overrides.get("ignore_globs", []),
  198. "base_url": default_overrides.get("base_url", "https://api.wandb.ai"),
  199. "root_dir": default_overrides.get("root_dir"),
  200. "api_key": default_overrides.get("api_key"),
  201. "entity": default_overrides.get("entity"),
  202. "organization": default_overrides.get("organization"),
  203. "project": default_overrides.get("project"),
  204. "_extra_http_headers": default_overrides.get("_extra_http_headers"),
  205. "_proxies": default_overrides.get("_proxies"),
  206. }
  207. if load_settings:
  208. global_settings = wandb_setup.singleton().settings
  209. if root_dir := self.default_settings["root_dir"]:
  210. global_settings = global_settings.model_copy()
  211. global_settings.root_dir = root_dir
  212. self._settings = global_settings.read_system_settings().all()
  213. else:
  214. self._settings = {}
  215. # Mutable settings set by the _file_stream_api
  216. self.dynamic_settings = {
  217. "system_sample_seconds": 2,
  218. "system_samples": 15,
  219. "heartbeat_seconds": 30,
  220. }
  221. self.retry_timedelta = retry_timedelta or datetime.timedelta(days=7)
  222. self.retry_uploads = 10
  223. # todo: remove these hacky hacks after settings refactor is complete
  224. # keeping this code here to limit scope and so that it is easy to remove later
  225. self._extra_http_headers = self.settings("_extra_http_headers") or json.loads(
  226. self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
  227. )
  228. auth = None
  229. api_key = api_key or self.default_settings.get("api_key")
  230. if api_key:
  231. auth = ("api", api_key)
  232. elif self.access_token is not None:
  233. self._extra_http_headers["Authorization"] = f"Bearer {self.access_token}"
  234. else:
  235. auth = ("api", self.api_key or "")
  236. proxies = self.settings("_proxies") or json.loads(
  237. self._environ.get("WANDB__PROXIES", "{}")
  238. )
  239. self.client = Client(
  240. transport=GraphQLSession(
  241. headers={
  242. "User-Agent": self.user_agent,
  243. "X-WANDB-USERNAME": env.get_username(env=self._environ),
  244. "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
  245. **self._extra_http_headers,
  246. },
  247. use_json=True,
  248. # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
  249. # https://bugs.python.org/issue22889
  250. timeout=self.HTTP_TIMEOUT,
  251. auth=auth,
  252. url=f"{self.settings('base_url')}/graphql",
  253. proxies=proxies,
  254. )
  255. )
  256. self.retry_callback = retry_callback
  257. self._retry_gql = retry.Retry(
  258. self.execute,
  259. retry_timedelta=retry_timedelta,
  260. check_retry_fn=util.no_retry_auth,
  261. retryable_exceptions=(RetryError, requests.RequestException),
  262. retry_callback=retry_callback,
  263. )
  264. self._current_run_id: str | None = None
  265. self._file_stream_api = None
  266. self._upload_file_session = requests.Session()
  267. if self.FILE_PUSHER_TIMEOUT:
  268. self._upload_file_session.put = functools.partial( # type: ignore
  269. self._upload_file_session.put,
  270. timeout=self.FILE_PUSHER_TIMEOUT,
  271. )
  272. if proxies:
  273. self._upload_file_session.proxies.update(proxies)
  274. # This Retry class is initialized once for each Api instance, so this
  275. # defaults to retrying 1 million times per process or 7 days
  276. self.upload_file_retry = normalize_exceptions(
  277. retry.retriable(retry_timedelta=retry_timedelta)(self.upload_file)
  278. )
  279. self.upload_multipart_file_chunk_retry = normalize_exceptions(
  280. retry.retriable(retry_timedelta=retry_timedelta)(
  281. self.upload_multipart_file_chunk
  282. )
  283. )
  284. self._client_id_mapping: dict[str, str] = {}
  285. # Large file uploads to azure can optionally use their SDK
  286. self._azure_blob_module = util.get_module("azure.storage.blob")
  287. self._max_cli_version: str | None = None
  288. self._server_features_cache: dict[str, bool] | None = None
  289. def gql(self, *args: Any, **kwargs: Any) -> Any:
  290. ret = self._retry_gql(
  291. *args,
  292. retry_cancel_event=self.context.cancel_event,
  293. **kwargs,
  294. )
  295. return ret
  296. def set_local_context(self, api_context: context.Context | None) -> None:
  297. self._local_data.context = api_context
  298. def clear_local_context(self) -> None:
  299. self._local_data.context = None
  300. @property
  301. def context(self) -> context.Context:
  302. return self._local_data.context or self._global_context
  303. def reauth(self) -> None:
  304. """Ensure the current api key is set in the transport."""
  305. self.client.transport.session.auth = ("api", self.api_key or "")
  306. def relocate(self) -> None:
  307. """Ensure the current api points to the right server."""
  308. self.client.transport.url = "{}/graphql".format(self.settings("base_url"))
  309. def execute(self, *args: Any, **kwargs: Any) -> _Response:
  310. """Wrapper around execute that logs in cases of failure."""
  311. import requests
  312. try:
  313. return self.client.execute(*args, **kwargs) # type: ignore
  314. except requests.exceptions.HTTPError as err:
  315. response = err.response
  316. assert response is not None
  317. logger.exception("Error executing GraphQL.")
  318. for error in parse_backend_error_messages(response):
  319. wandb.termerror(f"Error while calling W&B API: {error} ({response})")
  320. raise
  321. def validate_api_key(self) -> bool:
  322. """Returns whether the API key stored on initialization is valid."""
  323. res = self.gql(gql("query { viewer { id } }"))
  324. return res is not None and res["viewer"] is not None
  325. def set_current_run_id(self, run_id: str) -> None:
  326. self._current_run_id = run_id
  327. @property
  328. def current_run_id(self) -> str | None:
  329. return self._current_run_id
  330. @property
  331. def user_agent(self) -> str:
  332. return f"W&B Internal Client {wandb.__version__}"
  333. @property
  334. def api_key(self) -> str | None:
  335. from wandb.sdk.lib import wbauth
  336. if ( #
  337. (auth := wbauth.session_credentials(host=self.api_url))
  338. and isinstance(auth, wbauth.AuthApiKey)
  339. ):
  340. return auth.api_key
  341. return (
  342. os.getenv(env.API_KEY)
  343. or wbauth.read_netrc_auth(host=self.api_url)
  344. or parse_sm_secrets().get(env.API_KEY)
  345. or self.default_settings.get("api_key")
  346. )
  347. @property
  348. def access_token(self) -> str | None:
  349. """Retrieves an access token for authentication.
  350. This function attempts to exchange an identity token for a temporary
  351. access token from the server, and save it to the credentials file.
  352. It uses the path to the identity token as defined in the environment
  353. variables. If the environment variable is not set, it returns None.
  354. Returns:
  355. str | None: The access token if available, otherwise None if
  356. no identity token is supplied.
  357. Raises:
  358. AuthenticationError: If the path to the identity token is not found.
  359. """
  360. token_file_str = self._environ.get(env.IDENTITY_TOKEN_FILE)
  361. if not token_file_str:
  362. return None
  363. token_file = Path(token_file_str)
  364. if not token_file.exists():
  365. raise AuthenticationError(f"Identity token file not found: {token_file}")
  366. auth = wbauth.AuthIdentityTokenFile(
  367. host=self.settings("base_url"),
  368. path=str(token_file),
  369. credentials_file=wandb_setup.singleton().settings.credentials_file,
  370. )
  371. return auth.fetch_access_token()
  372. @property
  373. def api_url(self) -> str:
  374. return self.settings("base_url") # type: ignore
  375. @property
  376. def app_url(self) -> str:
  377. return wandb.util.app_url(self.api_url)
  378. @property
  379. def default_entity(self) -> str:
  380. return self.viewer().get("entity") # type: ignore
  @overload
  def settings(self, key: None = None) -> dict[str, Any]: ...

  @overload
  def settings(self, key: str) -> Any: ...

  def settings(self, key: str | None = None) -> Any:
      """Return the effective settings, or a single setting value.

      Values are layered, lowest priority first: ``self.default_settings``,
      then ``self._settings``. Finally the ``entity``, ``organization``,
      ``project`` and ``base_url`` values are each passed through the matching
      ``env.get_*`` helper together with ``self._environ``.

      Args:
          key (str, optional): If provided, only this setting is returned.

      Returns:
          The full settings dict when ``key`` is ``None``, e.g.::

              {
                  "entity": "models",
                  "base_url": "https://api.wandb.ai",
                  "project": None,
                  "organization": "my-org",
              }

          otherwise the value stored under ``key``.

      Raises:
          KeyError: If ``key`` is given but not present in the settings.
      """
      result: dict[str, Any] = dict(self.default_settings)
      result.update(self._settings)
      # `result` already holds `self._settings` layered over the defaults, so
      # the merged value is exactly the fallback each env-aware getter needs.
      env_getters = {
          "entity": env.get_entity,
          "organization": env.get_organization,
          "project": env.get_project,
          "base_url": env.get_base_url,
      }
      result.update(
          {
              name: getter(result.get(name), env=self._environ)
              for name, getter in env_getters.items()
          }
      )
      return result if key is None else result[key]
  def clear_setting(self, key: str) -> None:
      """Drop ``key`` from the local settings overrides; unknown keys are ignored."""
      if key in self._settings:
          del self._settings[key]
  def set_setting(self, key: str, value: Any) -> None:
      """Store ``value`` as a local settings override and propagate side effects.

      ``entity`` and ``project`` are also written to ``self._environ`` through
      ``env.set_entity`` / ``env.set_project``. Changing ``base_url`` calls
      ``self.relocate()``.
      """
      self._settings[key] = value
      if key in ("entity", "project"):
          setter = env.set_entity if key == "entity" else env.set_project
          setter(value, env=self._environ)
      elif key == "base_url":
          self.relocate()
  445. def parse_slug(
  446. self, slug: str, project: str | None = None, run: str | None = None
  447. ) -> tuple[str, str]:
  448. """Parse a slug into a project and run.
  449. Args:
  450. slug (str): The slug to parse
  451. project (str, optional): The project to use, if not provided it will be
  452. inferred from the slug
  453. run (str, optional): The run to use, if not provided it will be inferred
  454. from the slug
  455. Returns:
  456. A dict with the project and run
  457. """
  458. if slug and "/" in slug:
  459. parts = slug.split("/")
  460. project = parts[0]
  461. run = parts[1]
  462. else:
  463. project = project or self.settings().get("project")
  464. if project is None:
  465. raise CommError("No default project configured.")
  466. run = run or slug or self.current_run_id or env.get_run(env=self._environ)
  467. assert run, "run must be specified"
  468. return project, run
  @normalize_exceptions
  def fail_run_queue_item(
      self,
      run_queue_item_id: str,
      message: str,
      stage: str,
      file_paths: list[str] | None = None,
  ) -> bool:
      """Mark a run queue item as failed.

      Args:
          run_queue_item_id: ID of the run queue item.
          message: Failure message.
          stage: Stage at which the failure happened.
          file_paths: Optional file paths, sent only when not ``None``.

      Returns:
          The ``success`` flag of the ``failRunQueueItem`` mutation.
      """
      variables: dict[str, str | (list[str] | None)] = {
          "runQueueItemId": run_queue_item_id,
          "message": message,
          "stage": stage,
      }
      # `filePaths` is only included when the caller supplied it.
      if file_paths is not None:
          variables["filePaths"] = file_paths
      fail_mutation = gql(
          """
          mutation failRunQueueItem($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
              failRunQueueItem(
                  input: {
                      runQueueItemId: $runQueueItemId
                      message: $message
                      stage: $stage
                      filePaths: $filePaths
                  }
              ) {
                  success
              }
          }
          """
      )
      response = self.gql(fail_mutation, variable_values=variables)
      succeeded: bool = response["failRunQueueItem"]["success"]
      return succeeded
  def _server_features(self) -> dict[str, bool]:
      """Return the server's feature flags as ``{feature_name: is_enabled}``.

      The result is stored on ``self._server_features_cache``. Once it is set
      (including to ``{}``), later calls return it without querying the server.

      Raises:
          Exception: Any error from the feature query other than the server
              schema lacking ``ServerInfo.features``.
      """
      # NOTE: Avoid caching via `@cached_property`, due to undocumented
      # locking behavior before Python 3.12.
      # See: https://github.com/python/cpython/issues/87634
      # `getattr` so this works even if the attribute was never initialized.
      cached = getattr(self, "_server_features_cache", None)
      if cached is not None:
          return cached

      query = gql(SERVER_FEATURES_QUERY_GQL)
      try:
          response = self.gql(query)
      except Exception as e:
          # Unfortunately we currently have to match on the text of the error message,
          # as the `gql` client raises `Exception` rather than a more specific error.
          if 'Cannot query field "features" on type "ServerInfo".' in str(e):
              # The server schema has no feature flags (presumably an older
              # server): treat it as having no features enabled.
              self._server_features_cache = {}
          else:
              raise
      else:
          info = ServerFeaturesQuery.model_validate(response).server_info
          if info and (feats := info.features):
              self._server_features_cache = {f.name: f.is_enabled for f in feats if f}
          else:
              self._server_features_cache = {}
      return self._server_features_cache
  523. def _server_supports(self, feature: int | str) -> bool:
  524. """Return whether the current server supports the given feature.
  525. NOTE: This is deprecated. Outside of this file, please use
  526. `ServiceApi.feature_enabled()`. The `ServiceApi` is a sort of
  527. replacement to this "internal" `Api` class.
  528. This also caches the underlying lookup of server feature flags,
  529. and it maps {feature_name (str) -> is_enabled (bool)}.
  530. Good to use for features that have a fallback mechanism for older servers.
  531. """
  532. # If we're given the protobuf enum value, convert to a string name.
  533. # NOTE: We deliberately use names (str) instead of enum values (int)
  534. # as the keys here, since:
  535. # - the server identifies features by their name, rather than (client-side) enum value
  536. # - the defined list of client-side flags may be behind the server-side list of flags
  537. key = ServerFeature.Name(feature) if isinstance(feature, int) else feature
  538. return self._server_features().get(key) or False
  @normalize_exceptions
  def update_run_queue_item_warning(
      self,
      run_queue_item_id: str,
      message: str,
      stage: str,
      file_paths: list[str] | None = None,
  ) -> bool:
      """Attach a warning to a run queue item.

      ``filePaths`` is always sent, as ``None`` when ``file_paths`` is not given.

      Returns:
          The ``success`` flag of the ``updateRunQueueItemWarning`` mutation.
      """
      warning_mutation = gql(
          """
          mutation updateRunQueueItemWarning($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
              updateRunQueueItemWarning(
                  input: {
                      runQueueItemId: $runQueueItemId
                      message: $message
                      stage: $stage
                      filePaths: $filePaths
                  }
              ) {
                  success
              }
          }
          """
      )
      variables = {
          "runQueueItemId": run_queue_item_id,
          "message": message,
          "stage": stage,
          "filePaths": file_paths,
      }
      response = self.gql(warning_mutation, variable_values=variables)
      succeeded: bool = response["updateRunQueueItemWarning"]["success"]
      return succeeded
  @normalize_exceptions
  def viewer(self) -> dict[str, Any]:
      """Return the current viewer's id, entity, username, flags and teams.

      An empty dict is returned when the response's ``viewer`` is missing or falsy.
      """
      viewer_query = gql(
          """
          query Viewer{
              viewer {
                  id
                  entity
                  username
                  flags
                  teams {
                      edges {
                          node {
                              name
                          }
                      }
                  }
              }
          }
          """
      )
      payload = self.gql(viewer_query)
      viewer_info = payload.get("viewer")
      return viewer_info if viewer_info else {}
  @normalize_exceptions
  def max_cli_version(self) -> str | None:
      """Return the server-advertised maximum CLI version, if any.

      The value is memoized on ``self._max_cli_version`` once it is not ``None``.
      """
      if self._max_cli_version is not None:
          return self._max_cli_version
      _, server_info = self.viewer_server_info()
      # `cliVersionInfo` can be present but null in the response. In that case
      # `.get("cliVersionInfo", {})` returns None and the chained `.get` would
      # raise AttributeError, so fall back to an empty dict explicitly.
      cli_version_info = server_info.get("cliVersionInfo") or {}
      self._max_cli_version = cli_version_info.get("max_cli_version")
      return self._max_cli_version
  @normalize_exceptions
  def viewer_server_info(self) -> tuple[dict[str, Any], dict[str, Any]]:
      """Fetch the viewer and server info in a single query.

      Returns:
          ``(viewer, server_info)``. Each element is ``{}`` when the
          corresponding response field is missing or falsy.
      """
      combined_query = gql(
          """
          query Viewer{
              viewer {
                  id
                  entity
                  username
                  email
                  flags
                  teams {
                      edges {
                          node {
                              name
                          }
                      }
                  }
              }
              serverInfo {
                  cliVersionInfo
                  latestLocalVersionInfo {
                      outOfDate
                      latestVersionString
                      versionOnThisInstanceString
                  }
              }
          }
          """
      )
      payload = self.gql(combined_query)
      viewer_info = payload.get("viewer") or {}
      server_info = payload.get("serverInfo") or {}
      return viewer_info, server_info
  @normalize_exceptions
  def list_projects(self, entity: str | None = None) -> list[dict[str, str]]:
      """List projects in W&B scoped by entity.

      Only the first 10 projects are requested (``first: 10``).

      Args:
          entity (str, optional): Entity to list projects for. Falls back to
              ``self.settings("entity")``.

      Returns:
          A list of ``{"id", "name", "description"}`` dicts.
      """
      projects_query = gql(
          """
          query EntityProjects($entity: String) {
              models(first: 10, entityName: $entity) {
                  edges {
                      node {
                          id
                          name
                          description
                      }
                  }
              }
          }
          """
      )
      entity_name = entity or self.settings("entity")
      response = self.gql(projects_query, variable_values={"entity": entity_name})
      projects: list[dict[str, str]] = self._flatten_edges(response["models"])
      return projects
  @normalize_exceptions
  def project(self, project: str, entity: str | None = None) -> _Response:
      """Retrieve a single project.

      Args:
          project (str): The project to get details for.
          entity (str, optional): The entity to scope this project to. Sent
              as-is; no settings fallback is applied here.

      Returns:
          The response's ``model`` field: a dict with ``id``, ``name``,
          ``repo``, ``dockerImage`` and ``description``, or ``None`` if the
          server returns no model.
      """
      query = gql(
          """
          query ProjectDetails($entity: String, $project: String) {
              model(name: $project, entityName: $entity) {
                  id
                  name
                  repo
                  dockerImage
                  description
              }
          }
          """
      )
      response: _Response = self.gql(
          query, variable_values={"entity": entity, "project": project}
      )["model"]
      return response
  @normalize_exceptions
  def sweep(
      self,
      sweep: str,
      specs: str,
      project: str | None = None,
      entity: str | None = None,
  ) -> dict[str, Any]:
      """Retrieve a sweep together with its runs.

      Args:
          sweep (str): Name of the sweep to fetch.
          specs (str): History specs, passed to each run's ``sampledHistory``.
          project (str, optional): Project to scope the sweep to. Falls back to
              ``self.settings("project")``.
          entity (str, optional): Entity to scope the sweep to. Falls back to
              ``self.settings("entity")``.

      Returns:
          The sweep's fields (``id``, ``name``, ``method``, ``state``,
          ``config``, ...), with ``runs`` flattened from GraphQL edges into a
          list of run dicts.

      Raises:
          ValueError: If the project or the sweep is not found.
      """
      query = gql(
          """
          query SweepWithRuns($entity: String, $project: String, $sweep: String!, $specs: [JSONString!]!) {
              project(name: $project, entityName: $entity) {
                  sweep(sweepName: $sweep) {
                      id
                      name
                      method
                      state
                      description
                      config
                      createdAt
                      heartbeatAt
                      updatedAt
                      earlyStopJobRunning
                      bestLoss
                      controller
                      scheduler
                      runs {
                          edges {
                              node {
                                  name
                                  state
                                  config
                                  exitcode
                                  heartbeatAt
                                  shouldStop
                                  failed
                                  stopped
                                  running
                                  summaryMetrics
                                  sampledHistory(specs: $specs)
                              }
                          }
                      }
                  }
              }
          }
          """
      )
      entity = entity or self.settings("entity")
      project = project or self.settings("project")
      response = self.gql(
          query,
          variable_values={
              "entity": entity,
              "project": project,
              "sweep": sweep,
              "specs": specs,
          },
      )
      if response["project"] is None or response["project"]["sweep"] is None:
          raise ValueError(f"Sweep {entity}/{project}/{sweep} not found")
      data: dict[str, Any] = response["project"]["sweep"]
      if data:
          data["runs"] = self._flatten_edges(data["runs"])
      return data
  @normalize_exceptions
  def list_runs(
      self, project: str, entity: str | None = None
  ) -> list[dict[str, str]]:
      """List runs in W&B scoped by project.

      Only the first 10 runs are requested (``buckets(first: 10)``).

      Args:
          project (str): The project to list runs for. Falls back to
              ``self.settings("project")`` if empty.
          entity (str, optional): The entity to scope this project to. Falls
              back to ``self.settings("entity")``.

      Returns:
          A list of ``{"id", "name", "displayName", "description"}`` dicts.
      """
      runs_query = gql(
          """
          query ProjectRuns($model: String!, $entity: String) {
              model(name: $model, entityName: $entity) {
                  buckets(first: 10) {
                      edges {
                          node {
                              id
                              name
                              displayName
                              description
                          }
                      }
                  }
              }
          }
          """
      )
      entity_name = entity or self.settings("entity")
      project_name = project or self.settings("project")
      response = self.gql(
          runs_query,
          variable_values={"entity": entity_name, "model": project_name},
      )
      return self._flatten_edges(response["model"]["buckets"])
  @normalize_exceptions
  def run_config(
      self, project: str, run: str | None = None, entity: str | None = None
  ) -> tuple[str, dict[str, Any], str | None, dict[str, Any]]:
      """Get the relevant configs for a run.

      Args:
          project (str): The project name (sent as the ``model`` name).
          run (str, optional): The run name (sent as the ``bucket`` name).
          entity (str, optional): The entity to scope this project to.

      Returns:
          A ``(commit, config, patch, metadata)`` tuple:

          - ``commit``: the run's ``commit`` field.
          - ``config``: the run's config parsed from JSON (``{}`` when empty).
          - ``patch``: the text of ``DIFF_FNAME``, or ``None`` if that file
            does not exist.
          - ``metadata``: the parsed JSON of ``METADATA_FNAME``, or ``{}`` if
            that file does not exist.

      Raises:
          CommError: If the project or the run is not found.
      """
      import requests

      check_httpclient_logger_handler()
      query = gql(
          """
          query RunConfigs(
              $name: String!,
              $entity: String,
              $run: String!,
              $pattern: String!,
              $includeConfig: Boolean!,
          ) {
              model(name: $name, entityName: $entity) {
                  bucket(name: $run) {
                      config @include(if: $includeConfig)
                      commit @include(if: $includeConfig)
                      files(pattern: $pattern) {
                          pageInfo {
                              hasNextPage
                              endCursor
                          }
                          edges {
                              node {
                                  name
                                  directUrl
                              }
                          }
                      }
                  }
              }
          }
          """
      )
      variable_values = {
          "name": project,
          "run": run,
          "entity": entity,
          "includeConfig": True,
      }
      commit: str = ""
      config: dict[str, Any] = {}
      patch: str | None = None
      metadata: dict[str, Any] = {}
      # If we use the `names` parameter on the `files` node, then the server
      # will helpfully give us an 'open' file handle to the files that don't
      # exist. This is so that we can upload data to it. However, in this
      # case, we just want to download that file and not upload to it, so
      # let's instead query for the files that do exist using `pattern`
      # (with no wildcards).
      #
      # Unfortunately we're unable to construct a single pattern that matches
      # our 2 files, we would need something like regex for that.
      for filename in [DIFF_FNAME, METADATA_FNAME]:
          variable_values["pattern"] = filename
          response = self.gql(query, variable_values=variable_values)
          if response["model"] is None:
              raise CommError(f"Run {entity}/{project}/{run} not found")
          run_obj: dict | None = response["model"]["bucket"]
          if run_obj is None:
              # The project resolved but no run with this name came back;
              # indexing into None below would raise a TypeError.
              raise CommError(f"Run {entity}/{project}/{run} not found")
          # we only need to fetch this config once
          if variable_values["includeConfig"]:
              commit = run_obj["commit"]
              config = json.loads(run_obj["config"] or "{}")
              variable_values["includeConfig"] = False
          if run_obj["files"] is not None:
              for file_edge in run_obj["files"]["edges"]:
                  name = file_edge["node"]["name"]
                  url = file_edge["node"]["directUrl"]
                  res = requests.get(url)
                  res.raise_for_status()
                  if name == METADATA_FNAME:
                      metadata = res.json()
                  elif name == DIFF_FNAME:
                      patch = res.text
      return commit, config, patch, metadata
  @normalize_exceptions
  def run_resume_status(
      self, entity: str, project_name: str, name: str
  ) -> dict[str, Any] | None:
      """Check if a run exists and get resume information.

      On a response that contains the project, this also stores ``project``
      (and the project's ``entity`` name, when present) via ``set_setting``.

      Args:
          entity (str): The entity to scope this project to.
          project_name (str): The project containing the run.
          name (str): The run name.

      Returns:
          The run (``bucket``) fields, or ``None`` when the response has no
          ``model``/``bucket`` key or the bucket itself is null.
      """
      # Pulling wandbConfig.start_time is required so that we can determine if a run has actually started
      query = gql(
          """
          query RunResumeStatus($project: String, $entity: String, $name: String!) {
              model(name: $project, entityName: $entity) {
                  id
                  name
                  entity {
                      id
                      name
                  }
                  bucket(name: $name, missingOk: true) {
                      id
                      name
                      summaryMetrics
                      displayName
                      logLineCount
                      historyLineCount
                      eventsLineCount
                      historyTail
                      eventsTail
                      config
                      tags
                      wandbConfig(keys: ["t"])
                  }
              }
          }
          """
      )
      response = self.gql(
          query,
          variable_values={
              "entity": entity,
              "project": project_name,
              "name": name,
          },
      )
      if "model" not in response or "bucket" not in (response["model"] or {}):
          return None
      project = response["model"]
      self.set_setting("project", project_name)
      # `entity` is always requested but may come back null; only record it
      # when it is actually populated instead of crashing on `None["name"]`.
      project_entity = project.get("entity")
      if project_entity:
          self.set_setting("entity", project_entity["name"])
      result: dict[str, Any] | None = project["bucket"]
      return result
  @normalize_exceptions
  def check_stop_requested(
      self, project_name: str, entity_name: str, run_id: str
  ) -> bool:
      """Return the run's ``stopped`` flag.

      Returns ``False`` when either the project or the run is missing (or
      falsy) in the response.
      """
      stopped_query = gql(
          """
          query RunStoppedStatus($projectName: String, $entityName: String, $runId: String!) {
              project(name:$projectName, entityName:$entityName) {
                  run(name:$runId) {
                      stopped
                  }
              }
          }
          """
      )
      payload = self.gql(
          stopped_query,
          variable_values={
              "projectName": project_name,
              "entityName": entity_name,
              "runId": run_id,
          },
      )
      project_info = payload.get("project") or {}
      run_info = project_info.get("run")
      if not run_info:
          return False
      stopped: bool = run_info["stopped"]
      return stopped
  def format_project(self, project: str) -> str:
      """Normalize a project name into a slug.

      The name is lowercased, each run of non-word characters becomes a
      single ``-``, and leading/trailing ``-`` and ``_`` are stripped.
      """
      lowered = project.lower()
      dashed = re.sub(r"\W+", "-", lowered)
      return dashed.strip("-_")
  @normalize_exceptions
  def upsert_project(
      self,
      project: str,
      id: str | None = None,
      description: str | None = None,
      entity: str | None = None,
  ) -> dict[str, Any]:
      """Create or update a project.

      Args:
          project (str): Project name; normalized with ``format_project``.
          id (str, optional): ID of an existing project to update.
          description (str, optional): A description of this project.
          entity (str, optional): The entity to scope this project to. Falls
              back to ``self.settings("entity")``.

      Returns:
          The upserted project's ``name`` and ``description``.
      """
      upsert_mutation = gql(
          """
          mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String) {
              upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                  model {
                      name
                      description
                  }
              }
          }
          """
      )
      variables = {
          "name": self.format_project(project),
          "entity": entity or self.settings("entity"),
          "description": description,
          "id": id,
      }
      response = self.gql(upsert_mutation, variable_values=variables)
      upserted: dict[str, Any] = response["upsertModel"]["model"]
      return upserted
  @normalize_exceptions
  def entity_is_team(self, entity: str) -> bool:
      """Return the ``isTeam`` flag of ``entity``.

      Raises:
          Exception: If the entity cannot be fetched (null in the response).
      """
      team_query = gql(
          """
          query EntityIsTeam($entity: String!) {
              entity(name: $entity) {
                  id
                  isTeam
              }
          }
          """
      )
      response = self.gql(team_query, {"entity": entity})
      entity_info = response.get("entity")
      if entity_info is None:
          raise Exception(
              f"Error fetching entity {entity} "
              "check that you have access to this entity"
          )
      is_team: bool = entity_info["isTeam"]
      return is_team
  @normalize_exceptions
  def get_project_run_queues(self, entity: str, project: str) -> list[dict[str, str]]:
      """Return the run queues (id, name, createdBy, access) of a project.

      Raises:
          Exception: If the project is null in the response.
      """
      queues_query = gql(
          """
          query ProjectRunQueues($entity: String!, $projectName: String!){
              project(entityName: $entity, name: $projectName) {
                  runQueues {
                      id
                      name
                      createdBy
                      access
                  }
              }
          }
          """
      )
      response = self.gql(queues_query, {"projectName": project, "entity": entity})
      project_info = response.get("project")
      if project_info is None:
          # circular dependency: (LAUNCH_DEFAULT_PROJECT = model-registry)
          # The default project is left out of the message.
          scope = entity if project == "model-registry" else f"{entity}/{project}"
          raise Exception(
              f"Error fetching run queues for {scope} "
              "check that you have access to this entity and project"
          )
      project_run_queues: list[dict[str, str]] = project_info["runQueues"]
      return project_run_queues
  @normalize_exceptions
  def create_default_resource_config(
      self,
      entity: str,
      resource: str,
      config: str,
      template_variables: dict[str, float | int | str] | None,
  ) -> dict[str, Any] | None:
      """Create a default resource config for ``entity``.

      ``template_variables`` is sent JSON-encoded, or as ``"{}"`` when ``None``.

      Returns:
          The ``createDefaultResourceConfig`` payload
          (``defaultResourceConfigID`` and ``success``).
      """
      if template_variables is None:
          encoded_templates = "{}"
      else:
          encoded_templates = json.dumps(template_variables)
      variables = {
          "entityName": entity,
          "resource": resource,
          "config": config,
          "templateVariables": encoded_templates,
      }
      create_mutation = gql(
          """
          mutation createDefaultResourceConfig(
              $entityName: String!,
              $resource: String!,
              $config: JSONString!,
              $templateVariables: JSONString
          ) {
              createDefaultResourceConfig(
                  input: {
                      entityName: $entityName,
                      resource: $resource,
                      config: $config,
                      templateVariables: $templateVariables
                  }
              ) {
                  defaultResourceConfigID
                  success
              }
          }
          """
      )
      response = self.gql(create_mutation, variables)
      created: dict[str, Any] | None = response["createDefaultResourceConfig"]
      return created
  @normalize_exceptions
  def create_run_queue(
      self,
      entity: str,
      project: str,
      queue_name: str,
      access: str,
      prioritization_mode: str | None = None,
      config_id: str | None = None,
  ) -> dict[str, Any] | None:
      """Create a run queue.

      Returns:
          The ``createRunQueue`` payload (``success`` and ``queueID``).
      """
      create_mutation = gql(
          """
          mutation createRunQueue(
              $entity: String!,
              $project: String!,
              $queueName: String!,
              $access: RunQueueAccessType!,
              $prioritizationMode: RunQueuePrioritizationMode,
              $defaultResourceConfigID: ID,
          ) {
              createRunQueue(
                  input: {
                      entityName: $entity,
                      projectName: $project,
                      queueName: $queueName,
                      access: $access,
                      prioritizationMode: $prioritizationMode
                      defaultResourceConfigID: $defaultResourceConfigID
                  }
              ) {
                  success
                  queueID
              }
          }
          """
      )
      response = self.gql(
          create_mutation,
          {
              "entity": entity,
              "project": project,
              "queueName": queue_name,
              "access": access,
              "prioritizationMode": prioritization_mode,
              "defaultResourceConfigID": config_id,
          },
      )
      created: dict[str, Any] | None = response["createRunQueue"]
      return created
  @normalize_exceptions
  def upsert_run_queue(
      self,
      queue_name: str,
      entity: str,
      resource_type: str,
      resource_config: dict,
      project: str = LAUNCH_DEFAULT_PROJECT,
      prioritization_mode: str | None = None,
      template_variables: dict | None = None,
      external_links: dict | None = None,
  ) -> dict[str, Any] | None:
      """Create or update a run queue.

      ``resource_config`` is always JSON-encoded. ``template_variables`` and
      ``external_links`` are JSON-encoded when truthy and sent as ``None``
      otherwise.

      Returns:
          The ``upsertRunQueue`` payload (``success`` and
          ``configSchemaValidationErrors``).
      """

      def _json_or_none(value: dict | None) -> str | None:
          # Empty dicts are deliberately sent as null, not "{}".
          return json.dumps(value) if value else None

      upsert_mutation = gql(
          """
          mutation upsertRunQueue(
              $entityName: String!
              $projectName: String!
              $queueName: String!
              $resourceType: String!
              $resourceConfig: JSONString!
              $templateVariables: JSONString
              $prioritizationMode: RunQueuePrioritizationMode
              $externalLinks: JSONString
              $clientMutationId: String
          ) {
              upsertRunQueue(
                  input: {
                      entityName: $entityName
                      projectName: $projectName
                      queueName: $queueName
                      resourceType: $resourceType
                      resourceConfig: $resourceConfig
                      templateVariables: $templateVariables
                      prioritizationMode: $prioritizationMode
                      externalLinks: $externalLinks
                      clientMutationId: $clientMutationId
                  }
              ) {
                  success
                  configSchemaValidationErrors
              }
          }
          """
      )
      variables = {
          "entityName": entity,
          "projectName": project,
          "queueName": queue_name,
          "resourceType": resource_type,
          "resourceConfig": json.dumps(resource_config),
          "templateVariables": _json_or_none(template_variables),
          "prioritizationMode": prioritization_mode,
          "externalLinks": _json_or_none(external_links),
      }
      response: dict[str, Any] = self.gql(upsert_mutation, variables)
      return response["upsertRunQueue"]
  @normalize_exceptions
  def push_to_run_queue_by_name(
      self,
      entity: str,
      project: str,
      queue_name: str,
      run_spec: str,
      template_variables: dict[str, int | float | str] | None,
      priority: int | None = None,
  ) -> dict[str, Any] | None:
      """Push a run spec onto a run queue identified by its name.

      ``priority`` and ``template_variables`` are declared in the mutation
      only when they are not ``None``.

      Returns:
          The ``pushToRunQueueByName`` payload, with ``runSpec`` decoded from
          JSON when present, or ``None`` on any failure. If the server rejects
          the ``runSpec`` response field, the push is retried with a mutation
          that requests only ``runQueueItemId``. That retry declares only the
          entity/project/queue/runSpec variables.
      """
      variables: dict[str, Any] = {
          "entityName": entity,
          "projectName": project,
          "queueName": queue_name,
          "runSpec": run_spec,
      }
      param_decls = [
          "$entityName: String!",
          "$projectName: String!",
          "$queueName: String!",
          "$runSpec: JSONString!",
      ]
      input_fields = [
          "entityName: $entityName",
          "projectName: $projectName",
          "queueName: $queueName",
          "runSpec: $runSpec",
      ]
      if priority is not None:
          variables["priority"] = priority
          param_decls.append("$priority: Int")
          input_fields.append("priority: $priority")
      if template_variables is not None:
          variables["templateVariableValues"] = json.dumps(template_variables)
          param_decls.append("$templateVariableValues: JSONString")
          input_fields.append("templateVariableValues: $templateVariableValues")
      params_text = ",\n".join(param_decls)
      inputs_text = ",\n".join(input_fields)
      push_mutation = gql(
          f"""
          mutation pushToRunQueueByName(
              {params_text}
          ) {{
              pushToRunQueueByName(
                  input: {{
                      {inputs_text}
                  }}
              ) {{
                  runQueueItemId
                  runSpec
              }}
          }}
          """
      )
      try:
          payload: dict[str, Any] | None = self.gql(
              push_mutation, variables, check_retry_fn=util.no_retry_4xx
          ).get("pushToRunQueueByName")
          if not payload:
              return None
          if payload.get("runSpec"):
              payload["runSpec"] = json.loads(str(payload["runSpec"]))
          return payload
      except Exception as e:
          schema_lacks_run_spec = (
              'Cannot query field "runSpec" on type "PushToRunQueueByNamePayload"'
              in str(e)
          )
          if not schema_lacks_run_spec:
              return None
          # Server doesn't know the `runSpec` payload field: retry without it.
          fallback_mutation = gql(
              """
              mutation pushToRunQueueByName(
                  $entityName: String!,
                  $projectName: String!,
                  $queueName: String!,
                  $runSpec: JSONString!,
              ) {
                  pushToRunQueueByName(
                      input: {
                          entityName: $entityName,
                          projectName: $projectName,
                          queueName: $queueName,
                          runSpec: $runSpec
                      }
                  ) {
                      runQueueItemId
                  }
              }
              """
          )
          try:
              payload = self.gql(
                  fallback_mutation, variables, check_retry_fn=util.no_retry_4xx
              ).get("pushToRunQueueByName")
          except Exception:
              payload = None
          return payload
  @normalize_exceptions
  def push_to_run_queue(
      self,
      queue_name: str,
      launch_spec: dict[str, str],
      template_variables: dict | None,
      project_queue: str,
      priority: int | None = None,
  ) -> dict[str, Any] | None:
      """Push ``launch_spec`` onto the run queue ``queue_name``.

      The push is first attempted with ``push_to_run_queue_by_name``. If that
      returns nothing and no ``priority`` was requested, it falls back to the
      legacy flow: look the queue up by name in ``project_queue`` (creating a
      ``"default"`` queue when it is missing) and push by queue ID.

      Returns:
          The push payload (containing ``runQueueItemId``), or ``None`` if the
          by-name push failed with ``priority`` set, the queue could not be
          found or created, or the queue name is ambiguous.

      Raises:
          CommError: If the legacy ``pushToRunQueue`` mutation returns nothing.
      """
      # Queues are looked up under `queue_entity` when set, else the run's entity.
      entity = launch_spec.get("queue_entity") or launch_spec["entity"]
      run_spec = json.dumps(launch_spec)
      push_result = self.push_to_run_queue_by_name(
          entity, project_queue, queue_name, run_spec, template_variables, priority
      )
      if push_result:
          return push_result
      if priority is not None:
          # Cannot proceed with legacy method if priority is set
          return None

      # Legacy method: resolve the queue ID by name, then push by ID.
      queues_found = self.get_project_run_queues(entity, project_queue)
      matching_queues = [
          q
          for q in queues_found
          if q["name"] == queue_name
          # ensure user has access to queue
          and (
              # TODO: User created queues in the UI have USER access
              q["access"] in ["PROJECT", "USER"]
              or q["createdBy"] == self.default_entity
          )
      ]
      if not matching_queues:
          # in the case of a missing default queue. create it
          if queue_name == "default":
              wandb.termlog(
                  f"No default queue existing for entity: {entity} in project: {project_queue}, creating one."
              )
              # Create the queue under `entity`: the entity it was looked up in
              # and the one named in the log/error messages here.
              res = self.create_run_queue(
                  entity,
                  project_queue,
                  queue_name,
                  access="PROJECT",
              )
              if res is None or res.get("queueID") is None:
                  wandb.termerror(
                      f"Unable to create default queue for entity: {entity} on project: {project_queue}. Run could not be added to a queue"
                  )
                  return None
              queue_id = res["queueID"]
          else:
              if project_queue == "model-registry":
                  _msg = f"Unable to push to run queue {queue_name}. Queue not found."
              else:
                  _msg = f"Unable to push to run queue {project_queue}/{queue_name}. Queue not found."
              wandb.termwarn(_msg)
              return None
      elif len(matching_queues) > 1:
          wandb.termerror(
              f"Unable to push to run queue {queue_name}. More than one queue found with this name."
          )
          return None
      else:
          queue_id = matching_queues[0]["id"]

      # `launch_spec` is unchanged since `run_spec` was serialized above.
      variables: dict[str, Any] = {"queueID": queue_id, "runSpec": run_spec}
      mutation_params = """
          $queueID: ID!,
          $runSpec: JSONString!
      """
      mutation_input = """
          queueID: $queueID,
          runSpec: $runSpec
      """
      if template_variables is not None:
          mutation_params += ", $templateVariableValues: JSONString"
          mutation_input += ", templateVariableValues: $templateVariableValues"
          variables["templateVariableValues"] = json.dumps(template_variables)
      mutation = gql(
          f"""
          mutation pushToRunQueue(
              {mutation_params}
          ) {{
              pushToRunQueue(
                  input: {{{mutation_input}}}
              ) {{
                  runQueueItemId
              }}
          }}
          """
      )
      response = self.gql(mutation, variable_values=variables)
      if not response.get("pushToRunQueue"):
          raise CommError(f"Error pushing run queue item to queue {queue_name}.")
      result: dict[str, Any] | None = response["pushToRunQueue"]
      return result
  @normalize_exceptions
  def pop_from_run_queue(
      self,
      queue_name: str,
      entity: str | None = None,
      project: str | None = None,
      agent_id: str | None = None,
  ) -> dict[str, Any] | None:
      """Pop the next item from a run queue.

      Returns:
          The ``popFromRunQueue`` payload (``runQueueItemId`` and
          ``runSpec``), as returned by the server.
      """
      pop_mutation = gql(
          """
          mutation popFromRunQueue($entity: String!, $project: String!, $queueName: String!, $launchAgentId: ID) {
              popFromRunQueue(input: {
                  entityName: $entity,
                  projectName: $project,
                  queueName: $queueName,
                  launchAgentId: $launchAgentId
              }) {
                  runQueueItemId
                  runSpec
              }
          }
          """
      )
      variables = {
          "entity": entity,
          "project": project,
          "queueName": queue_name,
          "launchAgentId": agent_id,
      }
      response = self.gql(pop_mutation, variable_values=variables)
      popped: dict[str, Any] | None = response["popFromRunQueue"]
      return popped
  @normalize_exceptions
  def ack_run_queue_item(self, item_id: str, run_id: str | None = None) -> bool:
      """Acknowledge a run queue item, associating it with a run name.

      Returns:
          ``True`` (the mutation's ``success`` flag) on success.

      Raises:
          CommError: If the server reports the acknowledgement failed.
      """
      ack_mutation = gql(
          """
          mutation ackRunQueueItem($itemId: ID!, $runId: String!) {
              ackRunQueueItem(input: { runQueueItemId: $itemId, runName: $runId }) {
                  success
              }
          }
          """
      )
      # NOTE(review): `str(run_id)` sends the literal string "None" when
      # `run_id` is omitted -- confirm this is intended.
      response = self.gql(
          ack_mutation, variable_values={"itemId": item_id, "runId": str(run_id)}
      )
      succeeded: bool = response["ackRunQueueItem"]["success"]
      if not succeeded:
          raise CommError(
              "Error acking run queue item. Item may have already been acknowledged by another process"
          )
      return succeeded
  1473. @normalize_exceptions
  1474. def create_launch_agent(
  1475. self,
  1476. entity: str,
  1477. project: str,
  1478. queues: list[str],
  1479. agent_config: dict[str, Any],
  1480. version: str,
  1481. ) -> dict:
  1482. project_queues = self.get_project_run_queues(entity, project)
  1483. if not project_queues:
  1484. # create default queue if it doesn't already exist
  1485. default = self.create_run_queue(
  1486. entity, project, "default", access="PROJECT"
  1487. )
  1488. if default is None or default.get("queueID") is None:
  1489. raise CommError(
  1490. f"Unable to create default queue for {entity}/{project}. No queues for agent to poll"
  1491. )
  1492. project_queues = [{"id": default["queueID"], "name": "default"}]
  1493. polling_queue_ids = [
  1494. q["id"] for q in project_queues if q["name"] in queues
  1495. ] # filter to poll specified queues
  1496. if len(polling_queue_ids) != len(queues):
  1497. raise CommError(
  1498. f"Could not start launch agent: Not all of requested queues ({', '.join(queues)}) found. "
  1499. f"Available queues for this project: {','.join([q['name'] for q in project_queues])}"
  1500. )
  1501. hostname = socket.gethostname()
  1502. variable_values = {
  1503. "entity": entity,
  1504. "project": project,
  1505. "queues": polling_queue_ids,
  1506. "hostname": hostname,
  1507. "agentConfig": json.dumps(agent_config),
  1508. "version": version,
  1509. }
  1510. mutation_params = """
  1511. $entity: String!,
  1512. $project: String!,
  1513. $queues: [ID!]!,
  1514. $hostname: String!,
  1515. $agentConfig: JSONString,
  1516. $version: String
  1517. """
  1518. mutation_input = """
  1519. entityName: $entity,
  1520. projectName: $project,
  1521. runQueues: $queues,
  1522. hostname: $hostname,
  1523. agentConfig: $agentConfig,
  1524. version: $version
  1525. """
  1526. mutation = gql(
  1527. f"""
  1528. mutation createLaunchAgent(
  1529. {mutation_params}
  1530. ) {{
  1531. createLaunchAgent(
  1532. input: {{
  1533. {mutation_input}
  1534. }}
  1535. ) {{
  1536. launchAgentId
  1537. }}
  1538. }}
  1539. """
  1540. )
  1541. result: dict = self.gql(mutation, variable_values)["createLaunchAgent"]
  1542. return result
  1543. @normalize_exceptions
  1544. def update_launch_agent_status(
  1545. self,
  1546. agent_id: str,
  1547. status: str,
  1548. ) -> dict:
  1549. mutation = gql(
  1550. """
  1551. mutation updateLaunchAgent($agentId: ID!, $agentStatus: String){
  1552. updateLaunchAgent(
  1553. input: {
  1554. launchAgentId: $agentId
  1555. agentStatus: $agentStatus
  1556. }
  1557. ) {
  1558. success
  1559. }
  1560. }
  1561. """
  1562. )
  1563. variable_values = {
  1564. "agentId": agent_id,
  1565. "agentStatus": status,
  1566. }
  1567. result: dict = self.gql(mutation, variable_values)["updateLaunchAgent"]
  1568. return result
  1569. @normalize_exceptions
  1570. def get_launch_agent(self, agent_id: str) -> dict:
  1571. query = gql(
  1572. """
  1573. query LaunchAgent($agentId: ID!) {
  1574. launchAgent(id: $agentId) {
  1575. id
  1576. name
  1577. runQueues
  1578. hostname
  1579. agentStatus
  1580. stopPolling
  1581. heartbeatAt
  1582. }
  1583. }
  1584. """
  1585. )
  1586. variable_values = {
  1587. "agentId": agent_id,
  1588. }
  1589. result: dict = self.gql(query, variable_values)["launchAgent"]
  1590. return result
  1591. @normalize_exceptions
  1592. def upsert_run(
  1593. self,
  1594. id: str | None = None,
  1595. name: str | None = None,
  1596. project: str | None = None,
  1597. host: str | None = None,
  1598. group: str | None = None,
  1599. tags: list[str] | None = None,
  1600. config: dict | None = None,
  1601. description: str | None = None,
  1602. entity: str | None = None,
  1603. state: str | None = None,
  1604. display_name: str | None = None,
  1605. notes: str | None = None,
  1606. repo: str | None = None,
  1607. job_type: str | None = None,
  1608. program_path: str | None = None,
  1609. commit: str | None = None,
  1610. sweep_name: str | None = None,
  1611. summary_metrics: str | None = None,
  1612. num_retries: int | None = None,
  1613. ) -> tuple[dict, bool]:
  1614. """Update a run.
  1615. Args:
  1616. id (str, optional): The existing run to update
  1617. name (str, optional): The name of the run to create
  1618. group (str, optional): Name of the group this run is a part of
  1619. project (str, optional): The name of the project
  1620. host (str, optional): The name of the host
  1621. tags (list, optional): A list of tags to apply to the run
  1622. config (dict, optional): The latest config params
  1623. description (str, optional): A description of this project
  1624. entity (str, optional): The entity to scope this project to.
  1625. display_name (str, optional): The display name of this project
  1626. notes (str, optional): Notes about this run
  1627. repo (str, optional): Url of the program's repository.
  1628. state (str, optional): State of the program.
  1629. job_type (str, optional): Type of job, e.g 'train'.
  1630. program_path (str, optional): Path to the program.
  1631. commit (str, optional): The Git SHA to associate the run with
  1632. sweep_name (str, optional): The name of the sweep this run is a part of
  1633. summary_metrics (str, optional): The JSON summary metrics
  1634. num_retries (int, optional): Number of retries
  1635. """
  1636. query_string = """
  1637. mutation UpsertBucket(
  1638. $id: String,
  1639. $name: String,
  1640. $project: String,
  1641. $entity: String,
  1642. $groupName: String,
  1643. $description: String,
  1644. $displayName: String,
  1645. $notes: String,
  1646. $commit: String,
  1647. $config: JSONString,
  1648. $host: String,
  1649. $debug: Boolean,
  1650. $program: String,
  1651. $repo: String,
  1652. $jobType: String,
  1653. $state: String,
  1654. $sweep: String,
  1655. $tags: [String!],
  1656. $summaryMetrics: JSONString,
  1657. ) {
  1658. upsertBucket(input: {
  1659. id: $id,
  1660. name: $name,
  1661. groupName: $groupName,
  1662. modelName: $project,
  1663. entityName: $entity,
  1664. description: $description,
  1665. displayName: $displayName,
  1666. notes: $notes,
  1667. config: $config,
  1668. commit: $commit,
  1669. host: $host,
  1670. debug: $debug,
  1671. jobProgram: $program,
  1672. jobRepo: $repo,
  1673. jobType: $jobType,
  1674. state: $state,
  1675. sweep: $sweep,
  1676. tags: $tags,
  1677. summaryMetrics: $summaryMetrics,
  1678. }) {
  1679. bucket {
  1680. id
  1681. name
  1682. displayName
  1683. description
  1684. config
  1685. sweepName
  1686. project {
  1687. id
  1688. name
  1689. entity {
  1690. id
  1691. name
  1692. }
  1693. }
  1694. historyLineCount
  1695. }
  1696. inserted
  1697. }
  1698. }
  1699. """
  1700. mutation = gql(query_string)
  1701. config_str = json.dumps(config) if config else None
  1702. if not description or description.isspace():
  1703. description = None
  1704. kwargs = {}
  1705. if num_retries is not None:
  1706. kwargs["num_retries"] = num_retries
  1707. variable_values = {
  1708. "id": id,
  1709. "entity": entity or self.settings("entity"),
  1710. "name": name,
  1711. "project": project or util.auto_project_name(program_path),
  1712. "groupName": group,
  1713. "tags": tags,
  1714. "description": description,
  1715. "config": config_str,
  1716. "commit": commit,
  1717. "displayName": display_name,
  1718. "notes": notes,
  1719. "host": None
  1720. if self.settings().get("anonymous") in ["allow", "must"]
  1721. else host,
  1722. "debug": env.is_debug(env=self._environ),
  1723. "repo": repo,
  1724. "program": program_path,
  1725. "jobType": job_type,
  1726. "state": state,
  1727. "sweep": sweep_name,
  1728. "summaryMetrics": summary_metrics,
  1729. }
  1730. # retry conflict errors for 2 minutes, default to no_auth_retry
  1731. check_retry_fn = util.make_check_retry_fn(
  1732. check_fn=util.check_retry_conflict_or_gone,
  1733. check_timedelta=datetime.timedelta(minutes=2),
  1734. fallback_retry_fn=util.no_retry_auth,
  1735. )
  1736. response = self.gql(
  1737. mutation,
  1738. variable_values=variable_values,
  1739. check_retry_fn=check_retry_fn,
  1740. **kwargs,
  1741. )
  1742. run_obj: dict[str, dict[str, dict[str, str]]] = response["upsertBucket"][
  1743. "bucket"
  1744. ]
  1745. project_obj: dict[str, dict[str, str]] = run_obj.get("project", {})
  1746. if project_obj:
  1747. self.set_setting("project", project_obj["name"])
  1748. entity_obj = project_obj.get("entity", {})
  1749. if entity_obj:
  1750. self.set_setting("entity", entity_obj["name"])
  1751. return (
  1752. response["upsertBucket"]["bucket"],
  1753. response["upsertBucket"]["inserted"],
  1754. )
  1755. @normalize_exceptions
  1756. def rewind_run(
  1757. self,
  1758. run_name: str,
  1759. metric_name: str,
  1760. metric_value: float,
  1761. program_path: str | None = None,
  1762. entity: str | None = None,
  1763. project: str | None = None,
  1764. num_retries: int | None = None,
  1765. ) -> dict:
  1766. """Rewinds a run to a previous state.
  1767. Args:
  1768. run_name (str): The name of the run to rewind
  1769. metric_name (str): The name of the metric to rewind to
  1770. metric_value (float): The value of the metric to rewind to
  1771. program_path (str, optional): Path to the program
  1772. entity (str, optional): The entity to scope this project to
  1773. project (str, optional): The name of the project
  1774. num_retries (int, optional): Number of retries
  1775. Returns:
  1776. A dict with the rewound run
  1777. {
  1778. "id": "run_id",
  1779. "name": "run_name",
  1780. "displayName": "run_display_name",
  1781. "description": "run_description",
  1782. "config": "stringified_run_config_json",
  1783. "sweepName": "run_sweep_name",
  1784. "project": {
  1785. "id": "project_id",
  1786. "name": "project_name",
  1787. "entity": {
  1788. "id": "entity_id",
  1789. "name": "entity_name"
  1790. }
  1791. },
  1792. "historyLineCount": 100,
  1793. }
  1794. """
  1795. query_string = """
  1796. mutation RewindRun($runName: String!, $entity: String, $project: String, $metricName: String!, $metricValue: Float!) {
  1797. rewindRun(input: {runName: $runName, entityName: $entity, projectName: $project, metricName: $metricName, metricValue: $metricValue}) {
  1798. rewoundRun {
  1799. id
  1800. name
  1801. displayName
  1802. description
  1803. config
  1804. sweepName
  1805. project {
  1806. id
  1807. name
  1808. entity {
  1809. id
  1810. name
  1811. }
  1812. }
  1813. historyLineCount
  1814. }
  1815. }
  1816. }
  1817. """
  1818. mutation = gql(query_string)
  1819. kwargs = {}
  1820. if num_retries is not None:
  1821. kwargs["num_retries"] = num_retries
  1822. variable_values = {
  1823. "runName": run_name,
  1824. "entity": entity or self.settings("entity"),
  1825. "project": project or util.auto_project_name(program_path),
  1826. "metricName": metric_name,
  1827. "metricValue": metric_value,
  1828. }
  1829. # retry conflict errors for 2 minutes, default to no_auth_retry
  1830. check_retry_fn = util.make_check_retry_fn(
  1831. check_fn=util.check_retry_conflict_or_gone,
  1832. check_timedelta=datetime.timedelta(minutes=2),
  1833. fallback_retry_fn=util.no_retry_auth,
  1834. )
  1835. response = self.gql(
  1836. mutation,
  1837. variable_values=variable_values,
  1838. check_retry_fn=check_retry_fn,
  1839. **kwargs,
  1840. )
  1841. run_obj: dict[str, dict[str, dict[str, str]]] = response.get(
  1842. "rewindRun", {}
  1843. ).get("rewoundRun", {})
  1844. project_obj: dict[str, dict[str, str]] = run_obj.get("project", {})
  1845. if project_obj:
  1846. self.set_setting("project", project_obj["name"])
  1847. entity_obj = project_obj.get("entity", {})
  1848. if entity_obj:
  1849. self.set_setting("entity", entity_obj["name"])
  1850. return run_obj
  1851. @normalize_exceptions
  1852. def get_run_info(
  1853. self,
  1854. entity: str,
  1855. project: str,
  1856. name: str,
  1857. ) -> dict:
  1858. query = gql(
  1859. """
  1860. query RunInfo($project: String!, $entity: String!, $name: String!) {
  1861. project(name: $project, entityName: $entity) {
  1862. run(name: $name) {
  1863. runInfo {
  1864. program
  1865. args
  1866. os
  1867. python
  1868. colab
  1869. executable
  1870. codeSaved
  1871. cpuCount
  1872. gpuCount
  1873. gpu
  1874. git {
  1875. remote
  1876. commit
  1877. }
  1878. }
  1879. }
  1880. }
  1881. }
  1882. """
  1883. )
  1884. variable_values = {"project": project, "entity": entity, "name": name}
  1885. res = self.gql(query, variable_values)
  1886. if res.get("project") is None:
  1887. raise CommError(
  1888. f"Error fetching run info for {entity}/{project}/{name}. Check that this project exists and you have access to this entity and project"
  1889. )
  1890. elif res["project"].get("run") is None:
  1891. raise CommError(
  1892. f"Error fetching run info for {entity}/{project}/{name}. Check that this run id exists"
  1893. )
  1894. run_info: dict = res["project"]["run"]["runInfo"]
  1895. return run_info
  1896. @normalize_exceptions
  1897. def get_run_state(self, entity: str, project: str, name: str) -> str:
  1898. query = gql(
  1899. """
  1900. query RunState(
  1901. $project: String!,
  1902. $entity: String!,
  1903. $name: String!) {
  1904. project(name: $project, entityName: $entity) {
  1905. run(name: $name) {
  1906. state
  1907. }
  1908. }
  1909. }
  1910. """
  1911. )
  1912. variable_values = {
  1913. "project": project,
  1914. "entity": entity,
  1915. "name": name,
  1916. }
  1917. res = self.gql(query, variable_values)
  1918. if res.get("project") is None or res["project"].get("run") is None:
  1919. raise CommError(f"Error fetching run state for {entity}/{project}/{name}.")
  1920. run_state: str = res["project"]["run"]["state"]
  1921. return run_state
  1922. @normalize_exceptions
  1923. def upload_urls(
  1924. self,
  1925. project: str,
  1926. files: list[str] | dict[str, IO],
  1927. run: str | None = None,
  1928. entity: str | None = None,
  1929. description: str | None = None,
  1930. ) -> tuple[str, list[str], dict[str, dict[str, Any]]]:
  1931. """Generate temporary resumable upload urls.
  1932. Args:
  1933. project (str): The project to download
  1934. files (list or dict): The filenames to upload
  1935. run (str, optional): The run to upload to
  1936. entity (str, optional): The entity to scope this project to.
  1937. description (str, optional): description
  1938. Returns:
  1939. (run_id, upload_headers, file_info)
  1940. run_id: id of run we uploaded files to
  1941. upload_headers: A list of headers to use when uploading files.
  1942. file_info: A dict of filenames and urls.
  1943. {
  1944. "run_id": "run_id",
  1945. "upload_headers": [""],
  1946. "file_info": [
  1947. { "weights.h5": { "uploadUrl": "https://weights.url" } },
  1948. { "model.json": { "uploadUrl": "https://model.json" } }
  1949. ]
  1950. }
  1951. """
  1952. run_name = run or self.current_run_id
  1953. assert run_name, "run must be specified"
  1954. entity = entity or self.settings("entity")
  1955. assert entity, "entity must be specified"
  1956. query = gql(
  1957. """
  1958. mutation CreateRunFiles($entity: String!, $project: String!, $run: String!, $files: [String!]!) {
  1959. createRunFiles(input: {entityName: $entity, projectName: $project, runName: $run, files: $files}) {
  1960. runID
  1961. uploadHeaders
  1962. files {
  1963. name
  1964. uploadUrl
  1965. }
  1966. }
  1967. }
  1968. """
  1969. )
  1970. query_result = self.gql(
  1971. query,
  1972. variable_values={
  1973. "project": project,
  1974. "run": run_name,
  1975. "entity": entity,
  1976. "files": [file for file in files],
  1977. },
  1978. )
  1979. result = query_result["createRunFiles"]
  1980. run_id = result["runID"]
  1981. if not run_id:
  1982. raise CommError(
  1983. f"Error uploading files to {entity}/{project}/{run_name}. Check that this project exists and you have access to this entity and project"
  1984. )
  1985. file_name_urls = {file["name"]: file for file in result["files"]}
  1986. return run_id, result["uploadHeaders"], file_name_urls
  1987. def legacy_upload_urls(
  1988. self,
  1989. project: str,
  1990. files: list[str] | dict[str, IO],
  1991. run: str | None = None,
  1992. entity: str | None = None,
  1993. description: str | None = None,
  1994. ) -> tuple[str, list[str], dict[str, dict[str, Any]]]:
  1995. """Generate temporary resumable upload urls.
  1996. A new mutation createRunFiles was introduced after 0.15.4.
  1997. This function is used to support older versions.
  1998. """
  1999. query = gql(
  2000. """
  2001. query RunUploadUrls($name: String!, $files: [String]!, $entity: String, $run: String!, $description: String) {
  2002. model(name: $name, entityName: $entity) {
  2003. bucket(name: $run, desc: $description) {
  2004. id
  2005. files(names: $files) {
  2006. uploadHeaders
  2007. edges {
  2008. node {
  2009. name
  2010. url(upload: true)
  2011. updatedAt
  2012. }
  2013. }
  2014. }
  2015. }
  2016. }
  2017. }
  2018. """
  2019. )
  2020. run_id = run or self.current_run_id
  2021. assert run_id, "run must be specified"
  2022. entity = entity or self.settings("entity")
  2023. query_result = self.gql(
  2024. query,
  2025. variable_values={
  2026. "name": project,
  2027. "run": run_id,
  2028. "entity": entity,
  2029. "files": [file for file in files],
  2030. "description": description,
  2031. },
  2032. )
  2033. run_obj = query_result["model"]["bucket"]
  2034. if run_obj:
  2035. for file_node in run_obj["files"]["edges"]:
  2036. file = file_node["node"]
  2037. # we previously used "url" field but now use "uploadUrl"
  2038. # replace the "url" field with "uploadUrl for downstream compatibility
  2039. if "url" in file and "uploadUrl" not in file:
  2040. file["uploadUrl"] = file.pop("url")
  2041. result = {
  2042. file["name"]: file for file in self._flatten_edges(run_obj["files"])
  2043. }
  2044. return run_obj["id"], run_obj["files"]["uploadHeaders"], result
  2045. else:
  2046. raise CommError(f"Run does not exist {entity}/{project}/{run_id}.")
  2047. @normalize_exceptions
  2048. def download_urls(
  2049. self,
  2050. project: str,
  2051. run: str | None = None,
  2052. entity: str | None = None,
  2053. ) -> dict[str, dict[str, str]]:
  2054. """Generate download urls.
  2055. Args:
  2056. project (str): The project to download
  2057. run (str): The run to upload to
  2058. entity (str, optional): The entity to scope this project to. Defaults to wandb models
  2059. Returns:
  2060. A dict of extensions and urls
  2061. {
  2062. 'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
  2063. 'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
  2064. }
  2065. """
  2066. query = gql(
  2067. """
  2068. query RunDownloadUrls($name: String!, $entity: String, $run: String!) {
  2069. model(name: $name, entityName: $entity) {
  2070. bucket(name: $run) {
  2071. files {
  2072. edges {
  2073. node {
  2074. name
  2075. url
  2076. md5
  2077. updatedAt
  2078. }
  2079. }
  2080. }
  2081. }
  2082. }
  2083. }
  2084. """
  2085. )
  2086. run = run or self.current_run_id
  2087. assert run, "run must be specified"
  2088. entity = entity or self.settings("entity")
  2089. query_result = self.gql(
  2090. query,
  2091. variable_values={
  2092. "name": project,
  2093. "run": run,
  2094. "entity": entity,
  2095. },
  2096. )
  2097. if query_result["model"] is None:
  2098. raise CommError(f"Run does not exist {entity}/{project}/{run}.")
  2099. files = self._flatten_edges(query_result["model"]["bucket"]["files"])
  2100. return {file["name"]: file for file in files if file}
  2101. @normalize_exceptions
  2102. def download_url(
  2103. self,
  2104. project: str,
  2105. file_name: str,
  2106. run: str | None = None,
  2107. entity: str | None = None,
  2108. ) -> dict[str, str] | None:
  2109. """Generate download urls.
  2110. Args:
  2111. project (str): The project to download
  2112. file_name (str): The name of the file to download
  2113. run (str): The run to upload to
  2114. entity (str, optional): The entity to scope this project to. Defaults to wandb models
  2115. Returns:
  2116. A dict of extensions and urls
  2117. { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
  2118. """
  2119. query = gql(
  2120. """
  2121. query RunDownloadUrl($name: String!, $fileName: String!, $entity: String, $run: String!) {
  2122. model(name: $name, entityName: $entity) {
  2123. bucket(name: $run) {
  2124. files(names: [$fileName]) {
  2125. edges {
  2126. node {
  2127. name
  2128. url
  2129. md5
  2130. updatedAt
  2131. }
  2132. }
  2133. }
  2134. }
  2135. }
  2136. }
  2137. """
  2138. )
  2139. run = run or self.current_run_id
  2140. assert run, "run must be specified"
  2141. query_result = self.gql(
  2142. query,
  2143. variable_values={
  2144. "name": project,
  2145. "run": run,
  2146. "fileName": file_name,
  2147. "entity": entity or self.settings("entity"),
  2148. },
  2149. )
  2150. if query_result["model"]:
  2151. files = self._flatten_edges(query_result["model"]["bucket"]["files"])
  2152. return files[0] if len(files) > 0 and files[0].get("updatedAt") else None
  2153. else:
  2154. return None
  2155. @normalize_exceptions
  2156. def download_file(self, url: str) -> tuple[int, requests.Response]:
  2157. """Initiate a streaming download.
  2158. Args:
  2159. url (str): The url to download
  2160. Returns:
  2161. A tuple of the content length and the streaming response
  2162. """
  2163. import requests
  2164. check_httpclient_logger_handler()
  2165. http_headers = {}
  2166. auth = None
  2167. if self.access_token is not None:
  2168. http_headers["Authorization"] = f"Bearer {self.access_token}"
  2169. else:
  2170. auth = ("api", self.api_key or "")
  2171. response = requests.get(
  2172. url,
  2173. auth=auth,
  2174. headers=http_headers,
  2175. stream=True,
  2176. )
  2177. response.raise_for_status()
  2178. return int(response.headers.get("content-length", 0)), response
  2179. @normalize_exceptions
  2180. def download_write_file(
  2181. self,
  2182. metadata: dict[str, str],
  2183. out_dir: str | None = None,
  2184. ) -> tuple[str, requests.Response | None]:
  2185. """Download a file from a run and write it to wandb/.
  2186. Args:
  2187. metadata (obj): The metadata object for the file to download. Comes from Api.download_urls().
  2188. out_dir (str, optional): The directory to write the file to. Defaults to wandb/
  2189. Returns:
  2190. A tuple of the file's local path and the streaming response. The streaming response is None if the file
  2191. already existed and was up-to-date.
  2192. """
  2193. filename = metadata["name"]
  2194. path = os.path.join(out_dir or self.settings("wandb_dir"), filename)
  2195. if self.file_current(filename, B64MD5(metadata["md5"])):
  2196. return path, None
  2197. size, response = self.download_file(metadata["url"])
  2198. with util.fsync_open(path, "wb") as file:
  2199. for data in response.iter_content(chunk_size=1024):
  2200. file.write(data)
  2201. return path, response
  2202. def upload_file_azure(
  2203. self, url: str, file: Any, extra_headers: dict[str, str]
  2204. ) -> None:
  2205. """Upload a file to azure."""
  2206. import requests
  2207. from azure.core.exceptions import AzureError # type: ignore
  2208. # Configure the client without retries so our existing logic can handle them
  2209. client = self._azure_blob_module.BlobClient.from_blob_url(
  2210. url, retry_policy=self._azure_blob_module.LinearRetry(retry_total=0)
  2211. )
  2212. try:
  2213. if extra_headers.get("Content-MD5") is not None:
  2214. md5: bytes | None = base64.b64decode(extra_headers["Content-MD5"])
  2215. else:
  2216. md5 = None
  2217. content_settings = self._azure_blob_module.ContentSettings(
  2218. content_md5=md5,
  2219. content_type=extra_headers.get("Content-Type"),
  2220. )
  2221. client.upload_blob(
  2222. file,
  2223. max_concurrency=4,
  2224. length=len(file),
  2225. overwrite=True,
  2226. content_settings=content_settings,
  2227. )
  2228. except AzureError as e:
  2229. if hasattr(e, "response"):
  2230. response = requests.models.Response()
  2231. response.status_code = e.response.status_code
  2232. response.headers = e.response.headers
  2233. raise requests.exceptions.RequestException(e.message, response=response)
  2234. else:
  2235. raise requests.exceptions.ConnectionError(e.message)
  2236. def upload_multipart_file_chunk(
  2237. self,
  2238. url: str,
  2239. upload_chunk: bytes,
  2240. extra_headers: dict[str, str] | None = None,
  2241. ) -> requests.Response | None:
  2242. """Upload a file chunk to S3 with failure resumption.
  2243. Args:
  2244. url: The url to download
  2245. upload_chunk: The path to the file you want to upload
  2246. extra_headers: A dictionary of extra headers to send with the request
  2247. Returns:
  2248. The `requests` library response object
  2249. """
  2250. import requests
  2251. check_httpclient_logger_handler()
  2252. try:
  2253. if env.is_debug(env=self._environ):
  2254. logger.debug("upload_file: %s", url)
  2255. response = self._upload_file_session.put(
  2256. url, data=upload_chunk, headers=extra_headers
  2257. )
  2258. if env.is_debug(env=self._environ):
  2259. logger.debug("upload_file: %s complete", url)
  2260. response.raise_for_status()
  2261. except requests.exceptions.RequestException as e:
  2262. logger.exception(f"upload_file exception for {url=}")
  2263. response_content = e.response.content if e.response is not None else ""
  2264. status_code = e.response.status_code if e.response is not None else 0
  2265. # S3 reports retryable request timeouts out-of-band
  2266. is_aws_retryable = status_code == 400 and "RequestTimeout" in str(
  2267. response_content
  2268. )
  2269. # Retry errors from cloud storage or local network issues
  2270. if (
  2271. status_code in (308, 408, 409, 429, 500, 502, 503, 504)
  2272. or isinstance(
  2273. e,
  2274. (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
  2275. )
  2276. or is_aws_retryable
  2277. ):
  2278. _e = retry.TransientError(exc=e)
  2279. raise _e.with_traceback(sys.exc_info()[2])
  2280. else:
  2281. get_sentry().reraise(e)
  2282. return response
  2283. def upload_file(
  2284. self,
  2285. url: str,
  2286. file: IO[bytes],
  2287. callback: ProgressFn | None = None,
  2288. extra_headers: dict[str, str] | None = None,
  2289. ) -> requests.Response | None:
  2290. """Upload a file to W&B with failure resumption.
  2291. Args:
  2292. url: The url to download
  2293. file: The path to the file you want to upload
  2294. callback: A callback which is passed the number of
  2295. bytes uploaded since the last time it was called, used to report progress
  2296. extra_headers: A dictionary of extra headers to send with the request
  2297. Returns:
  2298. The `requests` library response object
  2299. """
  2300. import requests
  2301. check_httpclient_logger_handler()
  2302. extra_headers = extra_headers.copy() if extra_headers else {}
  2303. response: requests.Response | None = None
  2304. progress = Progress(file, callback=callback)
  2305. try:
  2306. if "x-ms-blob-type" in extra_headers and self._azure_blob_module:
  2307. self.upload_file_azure(url, progress, extra_headers)
  2308. else:
  2309. if "x-ms-blob-type" in extra_headers:
  2310. wandb.termwarn(
  2311. "Azure uploads over 256MB require the azure SDK, install with pip install wandb[azure]",
  2312. repeat=False,
  2313. )
  2314. if env.is_debug(env=self._environ):
  2315. logger.debug("upload_file: %s", url)
  2316. response = self._upload_file_session.put(
  2317. url, data=progress, headers=extra_headers
  2318. )
  2319. if env.is_debug(env=self._environ):
  2320. logger.debug("upload_file: %s complete", url)
  2321. response.raise_for_status()
  2322. except requests.exceptions.RequestException as e:
  2323. logger.exception(f"upload_file exception for {url=}")
  2324. response_content = e.response.content if e.response is not None else ""
  2325. status_code = e.response.status_code if e.response is not None else 0
  2326. # S3 reports retryable request timeouts out-of-band
  2327. is_aws_retryable = (
  2328. "x-amz-meta-md5" in extra_headers
  2329. and status_code == 400
  2330. and "RequestTimeout" in str(response_content)
  2331. )
  2332. # We need to rewind the file for the next retry (the file passed in is `seek`'ed to 0)
  2333. progress.rewind()
  2334. # Retry errors from cloud storage or local network issues
  2335. if (
  2336. status_code in (308, 408, 409, 429, 500, 502, 503, 504)
  2337. or isinstance(
  2338. e,
  2339. (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
  2340. )
  2341. or is_aws_retryable
  2342. ):
  2343. _e = retry.TransientError(exc=e)
  2344. raise _e.with_traceback(sys.exc_info()[2])
  2345. else:
  2346. get_sentry().reraise(e)
  2347. return response
  2348. @normalize_exceptions
  2349. def register_agent(
  2350. self,
  2351. host: str,
  2352. sweep_id: str | None = None,
  2353. project_name: str | None = None,
  2354. entity: str | None = None,
  2355. ) -> dict:
  2356. """Register a new agent.
  2357. Args:
  2358. host (str): hostname
  2359. sweep_id (str): sweep id
  2360. project_name: (str): model that contains sweep
  2361. entity: (str): entity that contains sweep
  2362. """
  2363. mutation = gql(
  2364. """
  2365. mutation CreateAgent(
  2366. $host: String!
  2367. $projectName: String,
  2368. $entityName: String,
  2369. $sweep: String!
  2370. ) {
  2371. createAgent(input: {
  2372. host: $host,
  2373. projectName: $projectName,
  2374. entityName: $entityName,
  2375. sweep: $sweep,
  2376. }) {
  2377. agent {
  2378. id
  2379. }
  2380. }
  2381. }
  2382. """
  2383. )
  2384. if entity is None:
  2385. entity = self.settings("entity")
  2386. if project_name is None:
  2387. project_name = self.settings("project")
  2388. response = self.gql(
  2389. mutation,
  2390. variable_values={
  2391. "host": host,
  2392. "entityName": entity,
  2393. "projectName": project_name,
  2394. "sweep": sweep_id,
  2395. },
  2396. check_retry_fn=util.no_retry_4xx,
  2397. )
  2398. result: dict = response["createAgent"]["agent"]
  2399. return result
  2400. def agent_heartbeat(
  2401. self, agent_id: str, metrics: dict, run_states: dict
  2402. ) -> list[dict[str, Any]]:
  2403. """Notify server about agent state, receive commands.
  2404. Args:
  2405. agent_id (str): agent_id
  2406. metrics (dict): system metrics
  2407. run_states (dict): run_id: state mapping
  2408. Returns:
  2409. list of commands to execute.
  2410. Raises:
  2411. SweepNotFoundError: If the server returns a 404, indicating the
  2412. sweep was likely deleted.
  2413. """
  2414. import requests
  2415. from wandb.sdk.launch.sweeps import SweepNotFoundError
  2416. mutation = gql(
  2417. """
  2418. mutation Heartbeat(
  2419. $id: ID!,
  2420. $metrics: JSONString,
  2421. $runState: JSONString
  2422. ) {
  2423. agentHeartbeat(input: {
  2424. id: $id,
  2425. metrics: $metrics,
  2426. runState: $runState
  2427. }) {
  2428. agent {
  2429. id
  2430. }
  2431. commands
  2432. }
  2433. }
  2434. """
  2435. )
  2436. if agent_id is None:
  2437. raise ValueError("Cannot call heartbeat with an unregistered agent.")
  2438. try:
  2439. response = self.gql(
  2440. mutation,
  2441. variable_values={
  2442. "id": agent_id,
  2443. "metrics": json.dumps(metrics),
  2444. "runState": json.dumps(run_states),
  2445. },
  2446. timeout=60,
  2447. )
  2448. except requests.exceptions.HTTPError as e:
  2449. if e.response is not None and e.response.status_code == 404:
  2450. raise SweepNotFoundError(
  2451. "Sweep not found. The sweep may have been deleted."
  2452. ) from e
  2453. logger.exception("Error communicating with W&B.")
  2454. return []
  2455. except Exception:
  2456. logger.exception("Error communicating with W&B.")
  2457. return []
  2458. else:
  2459. result: list[dict[str, Any]] = json.loads(
  2460. response["agentHeartbeat"]["commands"]
  2461. )
  2462. return result
  2463. @staticmethod
  2464. def _validate_config_and_fill_distribution(config: dict) -> dict:
  2465. # verify that parameters are well specified.
  2466. # TODO(dag): deprecate this in favor of jsonschema validation once
  2467. # apiVersion 2 is released and local controller is integrated with
  2468. # wandb/client.
  2469. # avoid modifying the original config dict in
  2470. # case it is reused outside the calling func
  2471. config = deepcopy(config)
  2472. # explicitly cast to dict in case config was passed as a sweepconfig
  2473. # sweepconfig does not serialize cleanly to yaml and breaks graphql,
  2474. # but it is a subclass of dict, so this conversion is clean
  2475. config = dict(config)
  2476. if "parameters" not in config:
  2477. # still shows an anaconda warning, but doesn't error
  2478. return config
  2479. for parameter_name in config["parameters"]:
  2480. parameter = config["parameters"][parameter_name]
  2481. if (
  2482. "min" in parameter
  2483. and "max" in parameter
  2484. and "distribution" not in parameter
  2485. ):
  2486. if isinstance(parameter["min"], int) and isinstance(
  2487. parameter["max"], int
  2488. ):
  2489. parameter["distribution"] = "int_uniform"
  2490. elif isinstance(parameter["min"], float) and isinstance(
  2491. parameter["max"], float
  2492. ):
  2493. parameter["distribution"] = "uniform"
  2494. else:
  2495. raise ValueError(
  2496. f"Parameter {parameter_name} is ambiguous, please specify bounds as both floats (for a float_"
  2497. "uniform distribution) or ints (for an int_uniform distribution)."
  2498. )
  2499. return config
  2500. @normalize_exceptions
  2501. def upsert_sweep(
  2502. self,
  2503. config: dict,
  2504. controller: str | None = None,
  2505. launch_scheduler: str | None = None,
  2506. scheduler: str | None = None,
  2507. obj_id: str | None = None,
  2508. project: str | None = None,
  2509. entity: str | None = None,
  2510. state: str | None = None,
  2511. prior_runs: list[str] | None = None,
  2512. display_name: str | None = None,
  2513. template_variable_values: dict[str, Any] | None = None,
  2514. ) -> tuple[str, list[str]]:
  2515. """Upsert a sweep object.
  2516. Args:
  2517. config (dict): sweep config (will be converted to yaml)
  2518. controller (str): controller to use
  2519. launch_scheduler (str): launch scheduler to use
  2520. scheduler (str): scheduler to use
  2521. obj_id (str): object id
  2522. project (str): project to use
  2523. entity (str): entity to use
  2524. state (str): state
  2525. prior_runs (list): IDs of existing runs to add to the sweep
  2526. display_name (str): display name for the sweep
  2527. template_variable_values (dict): template variable values
  2528. """
  2529. import yaml
  2530. project_query = """
  2531. project {
  2532. id
  2533. name
  2534. entity {
  2535. id
  2536. name
  2537. }
  2538. }
  2539. """
  2540. mutation_str = """
  2541. mutation UpsertSweep(
  2542. $id: ID,
  2543. $config: String,
  2544. $description: String,
  2545. $entityName: String,
  2546. $projectName: String,
  2547. $controller: JSONString,
  2548. $scheduler: JSONString,
  2549. $state: String,
  2550. $priorRunsFilters: JSONString,
  2551. $displayName: String,
  2552. ) {
  2553. upsertSweep(input: {
  2554. id: $id,
  2555. config: $config,
  2556. description: $description,
  2557. entityName: $entityName,
  2558. projectName: $projectName,
  2559. controller: $controller,
  2560. scheduler: $scheduler,
  2561. state: $state,
  2562. priorRunsFilters: $priorRunsFilters,
  2563. displayName: $displayName,
  2564. }) {
  2565. sweep {
  2566. name
  2567. _PROJECT_QUERY_
  2568. }
  2569. configValidationWarnings
  2570. }
  2571. }
  2572. """
  2573. # TODO(jhr): we need protocol versioning to know schema is not supported
  2574. # for now we will just try both new and old query
  2575. mutation_5 = gql(
  2576. mutation_str.replace(
  2577. "$controller: JSONString,",
  2578. "$controller: JSONString,$launchScheduler: JSONString, $templateVariableValues: JSONString,",
  2579. )
  2580. .replace(
  2581. "controller: $controller,",
  2582. "controller: $controller,launchScheduler: $launchScheduler,templateVariableValues: $templateVariableValues,",
  2583. )
  2584. .replace("_PROJECT_QUERY_", project_query)
  2585. )
  2586. # launchScheduler was introduced in core v0.14.0
  2587. mutation_4 = gql(
  2588. mutation_str.replace(
  2589. "$controller: JSONString,",
  2590. "$controller: JSONString,$launchScheduler: JSONString,",
  2591. )
  2592. .replace(
  2593. "controller: $controller,",
  2594. "controller: $controller,launchScheduler: $launchScheduler",
  2595. )
  2596. .replace("_PROJECT_QUERY_", project_query)
  2597. )
  2598. # mutation 3 maps to backend that can support CLI version of at least 0.10.31
  2599. mutation_3 = gql(mutation_str.replace("_PROJECT_QUERY_", project_query))
  2600. mutation_2 = gql(
  2601. mutation_str.replace("_PROJECT_QUERY_", project_query).replace(
  2602. "configValidationWarnings", ""
  2603. )
  2604. )
  2605. mutation_1 = gql(
  2606. mutation_str.replace("_PROJECT_QUERY_", "").replace(
  2607. "configValidationWarnings", ""
  2608. )
  2609. )
  2610. # TODO(dag): replace this with a query for protocol versioning
  2611. mutations = [mutation_5, mutation_4, mutation_3, mutation_2, mutation_1]
  2612. config = self._validate_config_and_fill_distribution(config)
  2613. # Silly, but attr-dicts like Easydicts don't serialize correctly to yaml.
  2614. # This sanitizes them with a round trip pass through json to get a regular dict.
  2615. class NonOctalStringDumper(yaml.Dumper):
  2616. """Prevents strings containing non-octal values like "008" and "009" from being converted to numbers in in the yaml string saved as the sweep config."""
  2617. def represent_scalar(self, tag, value, style=None):
  2618. if (
  2619. tag == "tag:yaml.org,2002:str"
  2620. and value.startswith("0")
  2621. and len(value) > 1
  2622. ):
  2623. return super().represent_scalar(tag, value, style="'")
  2624. return super().represent_scalar(tag, value, style)
  2625. config_str = yaml.dump(
  2626. json.loads(json.dumps(config)), Dumper=NonOctalStringDumper
  2627. )
  2628. filters = None
  2629. if prior_runs:
  2630. filters = json.dumps({"$or": [{"name": r} for r in prior_runs]})
  2631. err: Exception | None = None
  2632. for mutation in mutations:
  2633. try:
  2634. variables = {
  2635. "id": obj_id,
  2636. "config": config_str,
  2637. "description": config.get("description"),
  2638. "entityName": entity or self.settings("entity"),
  2639. "projectName": project or self.settings("project"),
  2640. "controller": controller,
  2641. "launchScheduler": launch_scheduler,
  2642. "templateVariableValues": json.dumps(template_variable_values),
  2643. "scheduler": scheduler,
  2644. "priorRunsFilters": filters,
  2645. "displayName": display_name,
  2646. }
  2647. if state:
  2648. variables["state"] = state
  2649. response = self.gql(
  2650. mutation,
  2651. variable_values=variables,
  2652. check_retry_fn=util.no_retry_4xx,
  2653. )
  2654. except UsageError:
  2655. raise
  2656. except Exception as e:
  2657. # graphql schema exception is generic
  2658. err = e
  2659. continue
  2660. err = None
  2661. break
  2662. if err:
  2663. raise err
  2664. sweep: dict[str, dict[str, dict]] = response["upsertSweep"]["sweep"]
  2665. project_obj: dict[str, dict] = sweep.get("project", {})
  2666. if project_obj:
  2667. self.set_setting("project", project_obj["name"])
  2668. entity_obj: dict = project_obj.get("entity", {})
  2669. if entity_obj:
  2670. self.set_setting("entity", entity_obj["name"])
  2671. warnings = response["upsertSweep"].get("configValidationWarnings", [])
  2672. return response["upsertSweep"]["sweep"]["name"], warnings
  2673. @staticmethod
  2674. def file_current(fname: str, md5: B64MD5) -> bool:
  2675. """Checksum a file and compare the md5 with the known md5."""
  2676. return os.path.isfile(fname) and md5_file_b64(fname) == md5
  2677. @normalize_exceptions
  2678. def pull(
  2679. self, project: str, run: str | None = None, entity: str | None = None
  2680. ) -> list[requests.Response]:
  2681. """Download files from W&B.
  2682. Args:
  2683. project (str): The project to download
  2684. run (str, optional): The run to upload to
  2685. entity (str, optional): The entity to scope this project to. Defaults to wandb models
  2686. Returns:
  2687. The `requests` library response object
  2688. """
  2689. project, run = self.parse_slug(project, run=run)
  2690. urls = self.download_urls(project, run, entity)
  2691. responses = []
  2692. for filename in urls:
  2693. _, response = self.download_write_file(urls[filename])
  2694. if response:
  2695. responses.append(response)
  2696. return responses
  2697. def get_project(self) -> str:
  2698. project: str = self.default_settings.get("project") or self.settings("project")
  2699. return project
  2700. @normalize_exceptions
  2701. def push(
  2702. self,
  2703. files: list[str] | dict[str, IO],
  2704. run: str | None = None,
  2705. entity: str | None = None,
  2706. project: str | None = None,
  2707. description: str | None = None,
  2708. force: bool = True,
  2709. progress: TextIO | Literal[False] = False,
  2710. ) -> list[requests.Response | None]:
  2711. """Uploads multiple files to W&B.
  2712. Args:
  2713. files (list or dict): The filenames to upload, when dict the values are open files
  2714. run (str, optional): The run to upload to
  2715. entity (str, optional): The entity to scope this project to. Defaults to wandb models
  2716. project (str, optional): The name of the project to upload to. Defaults to the one in settings.
  2717. description (str, optional): The description of the changes
  2718. force (bool, optional): Whether to prevent push if git has uncommitted changes
  2719. progress (callable, or stream): If callable, will be called with (chunk_bytes,
  2720. total_bytes) as argument. If TextIO, renders a progress bar to it.
  2721. Returns:
  2722. A list of `requests.Response` objects
  2723. """
  2724. if project is None:
  2725. project = self.get_project()
  2726. if project is None:
  2727. raise CommError("No project configured.")
  2728. if run is None:
  2729. run = self.current_run_id
  2730. # TODO(adrian): we use a retriable version of self.upload_file() so
  2731. # will never retry self.upload_urls() here. Instead, maybe we should
  2732. # make push itself retriable.
  2733. _, upload_headers, result = self.upload_urls(
  2734. project,
  2735. files,
  2736. run,
  2737. entity,
  2738. )
  2739. extra_headers = {}
  2740. for upload_header in upload_headers:
  2741. key, val = upload_header.split(":", 1)
  2742. extra_headers[key] = val
  2743. responses = []
  2744. for file_name, file_info in result.items():
  2745. file_url = file_info["uploadUrl"]
  2746. # If the upload URL is relative, fill it in with the base URL,
  2747. # since it's a proxied file store like the on-prem VM.
  2748. if file_url.startswith("/"):
  2749. file_url = f"{self.api_url}{file_url}"
  2750. try:
  2751. # To handle Windows paths
  2752. # TODO: this doesn't handle absolute paths...
  2753. normal_name = os.path.join(*file_name.split("/"))
  2754. open_file = (
  2755. files[file_name]
  2756. if isinstance(files, dict)
  2757. else open(normal_name, "rb")
  2758. )
  2759. except OSError:
  2760. print(f"{file_name} does not exist") # noqa: T201
  2761. continue
  2762. if progress is False:
  2763. responses.append(
  2764. self.upload_file_retry(
  2765. file_info["uploadUrl"], open_file, extra_headers=extra_headers
  2766. )
  2767. )
  2768. else:
  2769. if callable(progress):
  2770. responses.append( # type: ignore
  2771. self.upload_file_retry(
  2772. file_url, open_file, progress, extra_headers=extra_headers
  2773. )
  2774. )
  2775. else:
  2776. length = os.fstat(open_file.fileno()).st_size
  2777. with click.progressbar( # type: ignore
  2778. file=progress,
  2779. length=length,
  2780. label=f"Uploading file: {file_name}",
  2781. fill_char=click.style("&", fg="green"),
  2782. ) as bar:
  2783. responses.append(
  2784. self.upload_file_retry(
  2785. file_url,
  2786. open_file,
  2787. lambda bites, _: bar.update(bites),
  2788. extra_headers=extra_headers,
  2789. )
  2790. )
  2791. open_file.close()
  2792. return responses
  2793. def link_artifact(
  2794. self,
  2795. client_id: str,
  2796. server_id: str,
  2797. portfolio_name: str,
  2798. entity: str,
  2799. project: str,
  2800. aliases: Sequence[str],
  2801. organization: str,
  2802. ) -> dict[str, Any]:
  2803. from wandb.sdk.artifacts._validators import is_artifact_registry_project
  2804. template = """
  2805. mutation LinkArtifact(
  2806. $artifactPortfolioName: String!,
  2807. $entityName: String!,
  2808. $projectName: String!,
  2809. $aliases: [ArtifactAliasInput!],
  2810. ID_TYPE
  2811. ) {
  2812. linkArtifact(input: {
  2813. artifactPortfolioName: $artifactPortfolioName,
  2814. entityName: $entityName,
  2815. projectName: $projectName,
  2816. aliases: $aliases,
  2817. ID_VALUE
  2818. }) {
  2819. versionIndex
  2820. }
  2821. }
  2822. """
  2823. org_entity = ""
  2824. if is_artifact_registry_project(project):
  2825. try:
  2826. org_entity = self._resolve_org_entity_name(
  2827. entity=entity, organization=organization
  2828. )
  2829. except ValueError as e:
  2830. wandb.termerror(str(e))
  2831. raise
  2832. def replace(a: str, b: str) -> None:
  2833. nonlocal template
  2834. template = template.replace(a, b)
  2835. if server_id:
  2836. replace("ID_TYPE", "$artifactID: ID")
  2837. replace("ID_VALUE", "artifactID: $artifactID")
  2838. elif client_id:
  2839. replace("ID_TYPE", "$clientID: ID")
  2840. replace("ID_VALUE", "clientID: $clientID")
  2841. variable_values = {
  2842. "clientID": client_id,
  2843. "artifactID": server_id,
  2844. "artifactPortfolioName": portfolio_name,
  2845. "entityName": org_entity or entity,
  2846. "projectName": project,
  2847. "aliases": [
  2848. {"alias": alias, "artifactCollectionName": portfolio_name}
  2849. for alias in aliases
  2850. ],
  2851. }
  2852. mutation = gql(template)
  2853. response = self.gql(mutation, variable_values=variable_values)
  2854. link_artifact: dict[str, Any] = response["linkArtifact"]
  2855. return link_artifact
  2856. def _resolve_org_entity_name(self, entity: str, organization: str = "") -> str:
  2857. # resolveOrgEntityName fetches the portfolio's org entity's name.
  2858. #
  2859. # The organization parameter may be empty, an org's display name, or an org entity name.
  2860. #
  2861. # If the server doesn't support fetching the org name of a portfolio, then this returns
  2862. # the organization parameter, or an error if it is empty. Otherwise, this returns the
  2863. # fetched value after validating that the given organization, if not empty, matches
  2864. # either the org's display or entity name.
  2865. if not entity:
  2866. raise ValueError("Entity name is required to resolve org entity name.")
  2867. orgs_from_entity = self._fetch_orgs_and_org_entities_from_entity(entity)
  2868. if organization:
  2869. return _match_org_with_fetched_org_entities(organization, orgs_from_entity)
  2870. # If no input organization provided, error if entity belongs to multiple orgs because we
  2871. # cannot determine which one to use.
  2872. if len(orgs_from_entity) > 1:
  2873. raise ValueError(
  2874. f"Personal entity {entity!r} belongs to multiple organizations "
  2875. "and cannot be used without specifying the organization name. "
  2876. "Please specify the organization in the Registry path or use a team entity in the entity settings."
  2877. )
  2878. return orgs_from_entity[0].entity_name
  2879. def _fetch_orgs_and_org_entities_from_entity(self, entity: str) -> list[_OrgNames]:
  2880. """Fetches organization entity names and display names for a given entity.
  2881. Args:
  2882. entity (str): Entity name to lookup. Can be either a personal or team entity.
  2883. Returns:
  2884. list[_OrgNames]: list of _OrgNames tuples. (_OrgNames(entity_name, display_name))
  2885. Raises:
  2886. ValueError: If entity is not found, has no organizations, or other validation errors.
  2887. """
  2888. query = gql(
  2889. """
  2890. query FetchOrgEntityFromEntity($entityName: String!) {
  2891. entity(name: $entityName) {
  2892. organization {
  2893. name
  2894. orgEntity {
  2895. name
  2896. }
  2897. }
  2898. user {
  2899. organizations {
  2900. name
  2901. orgEntity {
  2902. name
  2903. }
  2904. }
  2905. }
  2906. }
  2907. }
  2908. """
  2909. )
  2910. response = self.gql(
  2911. query,
  2912. variable_values={
  2913. "entityName": entity,
  2914. },
  2915. )
  2916. # Parse organization from response
  2917. entity_resp = response["entity"]["organization"]
  2918. user_resp = response["entity"]["user"]
  2919. # Check for organization under team/org entity type
  2920. if entity_resp:
  2921. org_name = entity_resp.get("name")
  2922. org_entity_name = entity_resp.get("orgEntity") and entity_resp[
  2923. "orgEntity"
  2924. ].get("name")
  2925. if not org_name or not org_entity_name:
  2926. raise ValueError(
  2927. f"Unable to find an organization under entity {entity!r}."
  2928. )
  2929. return [_OrgNames(entity_name=org_entity_name, display_name=org_name)]
  2930. # Check for organization under personal entity type, where a user can belong to multiple orgs
  2931. elif user_resp:
  2932. orgs = user_resp.get("organizations", [])
  2933. org_entities_return = [
  2934. _OrgNames(
  2935. entity_name=org["orgEntity"]["name"], display_name=org["name"]
  2936. )
  2937. for org in orgs
  2938. if org.get("orgEntity") and org.get("name")
  2939. ]
  2940. if not org_entities_return:
  2941. raise ValueError(
  2942. f"Unable to resolve an organization associated with personal entity: {entity!r}. "
  2943. "This could be because its a personal entity that doesn't belong to any organizations. "
  2944. "Please specify the organization in the Registry path or use a team entity in the entity settings."
  2945. )
  2946. return org_entities_return
  2947. else:
  2948. raise ValueError(f"Unable to find an organization under entity {entity!r}.")
  2949. def _construct_use_artifact_query(
  2950. self,
  2951. artifact_id: str,
  2952. entity_name: str | None = None,
  2953. project_name: str | None = None,
  2954. run_name: str | None = None,
  2955. use_as: str | None = None,
  2956. artifact_entity_name: str | None = None,
  2957. artifact_project_name: str | None = None,
  2958. ) -> tuple[Document, dict[str, Any]]:
  2959. query_vars = [
  2960. "$entityName: String!",
  2961. "$projectName: String!",
  2962. "$runName: String!",
  2963. "$artifactID: ID!",
  2964. ]
  2965. query_args = [
  2966. "entityName: $entityName",
  2967. "projectName: $projectName",
  2968. "runName: $runName",
  2969. "artifactID: $artifactID",
  2970. ]
  2971. if use_as:
  2972. query_vars.append("$usedAs: String")
  2973. query_args.append("usedAs: $usedAs")
  2974. entity_name = entity_name or self.settings("entity")
  2975. project_name = project_name or self.settings("project")
  2976. run_name = run_name or self.current_run_id
  2977. variable_values: dict[str, Any] = {
  2978. "entityName": entity_name,
  2979. "projectName": project_name,
  2980. "runName": run_name,
  2981. "artifactID": artifact_id,
  2982. "usedAs": use_as,
  2983. }
  2984. server_allows_entity_project_information = self._server_supports(
  2985. ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION
  2986. )
  2987. if server_allows_entity_project_information:
  2988. query_vars.extend(
  2989. [
  2990. "$artifactEntityName: String",
  2991. "$artifactProjectName: String",
  2992. ]
  2993. )
  2994. query_args.extend(
  2995. [
  2996. "artifactEntityName: $artifactEntityName",
  2997. "artifactProjectName: $artifactProjectName",
  2998. ]
  2999. )
  3000. variable_values["artifactEntityName"] = artifact_entity_name
  3001. variable_values["artifactProjectName"] = artifact_project_name
  3002. vars_str = ", ".join(query_vars)
  3003. args_str = ", ".join(query_args)
  3004. query = gql(
  3005. f"""
  3006. mutation UseArtifact({vars_str}) {{
  3007. useArtifact(input: {{{args_str}}}) {{
  3008. artifact {{
  3009. id
  3010. digest
  3011. description
  3012. state
  3013. createdAt
  3014. metadata
  3015. }}
  3016. }}
  3017. }}
  3018. """
  3019. )
  3020. return query, variable_values
  3021. def use_artifact(
  3022. self,
  3023. artifact_id: str,
  3024. entity_name: str | None = None,
  3025. project_name: str | None = None,
  3026. run_name: str | None = None,
  3027. artifact_entity_name: str | None = None,
  3028. artifact_project_name: str | None = None,
  3029. use_as: str | None = None,
  3030. ) -> dict[str, Any] | None:
  3031. query, variable_values = self._construct_use_artifact_query(
  3032. artifact_id,
  3033. entity_name,
  3034. project_name,
  3035. run_name,
  3036. use_as,
  3037. artifact_entity_name,
  3038. artifact_project_name,
  3039. )
  3040. response = self.gql(query, variable_values)
  3041. if response["useArtifact"]["artifact"]:
  3042. artifact: dict[str, Any] = response["useArtifact"]["artifact"]
  3043. return artifact
  3044. return None
  3045. def create_artifact_type(
  3046. self,
  3047. artifact_type_name: str,
  3048. entity_name: str | None = None,
  3049. project_name: str | None = None,
  3050. description: str | None = None,
  3051. ) -> str | None:
  3052. mutation = gql(
  3053. """
  3054. mutation CreateArtifactType(
  3055. $entityName: String!,
  3056. $projectName: String!,
  3057. $artifactTypeName: String!,
  3058. $description: String
  3059. ) {
  3060. createArtifactType(input: {
  3061. entityName: $entityName,
  3062. projectName: $projectName,
  3063. name: $artifactTypeName,
  3064. description: $description
  3065. }) {
  3066. artifactType {
  3067. id
  3068. }
  3069. }
  3070. }
  3071. """
  3072. )
  3073. entity_name = entity_name or self.settings("entity")
  3074. project_name = project_name or self.settings("project")
  3075. response = self.gql(
  3076. mutation,
  3077. variable_values={
  3078. "entityName": entity_name,
  3079. "projectName": project_name,
  3080. "artifactTypeName": artifact_type_name,
  3081. "description": description,
  3082. },
  3083. )
  3084. _id: str | None = response["createArtifactType"]["artifactType"]["id"]
  3085. return _id
  3086. def _get_create_artifact_mutation(
  3087. self,
  3088. history_step: int | None,
  3089. distributed_id: str | None,
  3090. ) -> str:
  3091. types = ""
  3092. values = ""
  3093. if history_step not in [0, None]:
  3094. types += "$historyStep: Int64!,"
  3095. values += "historyStep: $historyStep,"
  3096. if distributed_id:
  3097. types += "$distributedID: String,"
  3098. values += "distributedID: $distributedID,"
  3099. query_template = """
  3100. mutation CreateArtifact(
  3101. $artifactTypeName: String!,
  3102. $artifactCollectionNames: [String!],
  3103. $entityName: String!,
  3104. $projectName: String!,
  3105. $runName: String,
  3106. $description: String,
  3107. $digest: String!,
  3108. $aliases: [ArtifactAliasInput!],
  3109. $metadata: JSONString,
  3110. $clientID: ID,
  3111. $sequenceClientID: ID,
  3112. $ttlDurationSeconds: Int64,
  3113. $tags: [TagInput!],
  3114. _CREATE_ARTIFACT_ADDITIONAL_TYPE_
  3115. ) {
  3116. createArtifact(input: {
  3117. artifactTypeName: $artifactTypeName,
  3118. artifactCollectionNames: $artifactCollectionNames,
  3119. entityName: $entityName,
  3120. projectName: $projectName,
  3121. runName: $runName,
  3122. description: $description,
  3123. digest: $digest,
  3124. digestAlgorithm: MANIFEST_MD5,
  3125. aliases: $aliases,
  3126. metadata: $metadata,
  3127. clientID: $clientID,
  3128. sequenceClientID: $sequenceClientID,
  3129. enableDigestDeduplication: true,
  3130. ttlDurationSeconds: $ttlDurationSeconds,
  3131. tags: $tags,
  3132. _CREATE_ARTIFACT_ADDITIONAL_VALUE_
  3133. }) {
  3134. artifact {
  3135. id
  3136. state
  3137. artifactSequence {
  3138. id
  3139. latestArtifact {
  3140. id
  3141. versionIndex
  3142. }
  3143. }
  3144. }
  3145. }
  3146. }
  3147. """
  3148. return query_template.replace(
  3149. "_CREATE_ARTIFACT_ADDITIONAL_TYPE_", types
  3150. ).replace("_CREATE_ARTIFACT_ADDITIONAL_VALUE_", values)
  3151. def create_artifact(
  3152. self,
  3153. artifact_type_name: str,
  3154. artifact_collection_name: str,
  3155. digest: str,
  3156. client_id: str | None = None,
  3157. sequence_client_id: str | None = None,
  3158. entity_name: str | None = None,
  3159. project_name: str | None = None,
  3160. run_name: str | None = None,
  3161. description: str | None = None,
  3162. metadata: dict | None = None,
  3163. ttl_duration_seconds: int | None = None,
  3164. aliases: list[dict[str, str]] | None = None,
  3165. tags: list[dict[str, str]] | None = None,
  3166. distributed_id: str | None = None,
  3167. is_user_created: bool | None = False,
  3168. history_step: int | None = None,
  3169. ) -> tuple[dict, dict]:
  3170. query_template = self._get_create_artifact_mutation(
  3171. history_step,
  3172. distributed_id,
  3173. )
  3174. entity_name = entity_name or self.settings("entity")
  3175. project_name = project_name or self.settings("project")
  3176. if not is_user_created:
  3177. run_name = run_name or self.current_run_id
  3178. mutation = gql(query_template)
  3179. response = self.gql(
  3180. mutation,
  3181. variable_values={
  3182. "entityName": entity_name,
  3183. "projectName": project_name,
  3184. "runName": run_name,
  3185. "artifactTypeName": artifact_type_name,
  3186. "artifactCollectionNames": [artifact_collection_name],
  3187. "clientID": client_id,
  3188. "sequenceClientID": sequence_client_id,
  3189. "digest": digest,
  3190. "description": description,
  3191. "aliases": list(aliases or []),
  3192. "tags": list(tags or []),
  3193. "metadata": json.dumps(util.make_safe_for_json(metadata))
  3194. if metadata
  3195. else None,
  3196. "ttlDurationSeconds": ttl_duration_seconds,
  3197. "distributedID": distributed_id,
  3198. "historyStep": history_step,
  3199. },
  3200. )
  3201. av = response["createArtifact"]["artifact"]
  3202. latest = response["createArtifact"]["artifact"]["artifactSequence"].get(
  3203. "latestArtifact"
  3204. )
  3205. return av, latest
  3206. def commit_artifact(self, artifact_id: str) -> _Response:
  3207. mutation = gql(
  3208. """
  3209. mutation CommitArtifact(
  3210. $artifactID: ID!,
  3211. ) {
  3212. commitArtifact(input: {
  3213. artifactID: $artifactID,
  3214. }) {
  3215. artifact {
  3216. id
  3217. digest
  3218. }
  3219. }
  3220. }
  3221. """
  3222. )
  3223. response: _Response = self.gql(
  3224. mutation,
  3225. variable_values={"artifactID": artifact_id},
  3226. timeout=60,
  3227. )
  3228. return response
  3229. def complete_multipart_upload_artifact(
  3230. self,
  3231. artifact_id: str,
  3232. storage_path: str,
  3233. completed_parts: list[dict[str, Any]],
  3234. upload_id: str | None,
  3235. complete_multipart_action: str = "Complete",
  3236. ) -> str | None:
  3237. mutation = gql(
  3238. """
  3239. mutation CompleteMultipartUploadArtifact(
  3240. $completeMultipartAction: CompleteMultipartAction!,
  3241. $completedParts: [UploadPartsInput!]!,
  3242. $artifactID: ID!
  3243. $storagePath: String!
  3244. $uploadID: String!
  3245. ) {
  3246. completeMultipartUploadArtifact(
  3247. input: {
  3248. completeMultipartAction: $completeMultipartAction,
  3249. completedParts: $completedParts,
  3250. artifactID: $artifactID,
  3251. storagePath: $storagePath
  3252. uploadID: $uploadID
  3253. }
  3254. ) {
  3255. digest
  3256. }
  3257. }
  3258. """
  3259. )
  3260. response = self.gql(
  3261. mutation,
  3262. variable_values={
  3263. "completeMultipartAction": complete_multipart_action,
  3264. "artifactID": artifact_id,
  3265. "storagePath": storage_path,
  3266. "completedParts": completed_parts,
  3267. "uploadID": upload_id,
  3268. },
  3269. )
  3270. digest: str | None = response["completeMultipartUploadArtifact"]["digest"]
  3271. return digest
  3272. def create_artifact_manifest(
  3273. self,
  3274. name: str,
  3275. digest: str,
  3276. artifact_id: str | None,
  3277. base_artifact_id: str | None = None,
  3278. entity: str | None = None,
  3279. project: str | None = None,
  3280. run: str | None = None,
  3281. include_upload: bool = True,
  3282. type: str = "FULL",
  3283. ) -> tuple[str, dict[str, Any]]:
  3284. mutation = gql(
  3285. """
  3286. mutation CreateArtifactManifest(
  3287. $name: String!,
  3288. $digest: String!,
  3289. $artifactID: ID!,
  3290. $baseArtifactID: ID,
  3291. $entityName: String!,
  3292. $projectName: String!,
  3293. $runName: String!,
  3294. $includeUpload: Boolean!,
  3295. {}
  3296. ) {{
  3297. createArtifactManifest(input: {{
  3298. name: $name,
  3299. digest: $digest,
  3300. artifactID: $artifactID,
  3301. baseArtifactID: $baseArtifactID,
  3302. entityName: $entityName,
  3303. projectName: $projectName,
  3304. runName: $runName,
  3305. {}
  3306. }}) {{
  3307. artifactManifest {{
  3308. id
  3309. file {{
  3310. id
  3311. name
  3312. displayName
  3313. uploadUrl @include(if: $includeUpload)
  3314. uploadHeaders @include(if: $includeUpload)
  3315. }}
  3316. }}
  3317. }}
  3318. }}
  3319. """.format(
  3320. "$type: ArtifactManifestType = FULL" if type != "FULL" else "",
  3321. "type: $type" if type != "FULL" else "",
  3322. )
  3323. )
  3324. entity_name = entity or self.settings("entity")
  3325. project_name = project or self.settings("project")
  3326. run_name = run or self.current_run_id
  3327. response = self.gql(
  3328. mutation,
  3329. variable_values={
  3330. "name": name,
  3331. "digest": digest,
  3332. "artifactID": artifact_id,
  3333. "baseArtifactID": base_artifact_id,
  3334. "entityName": entity_name,
  3335. "projectName": project_name,
  3336. "runName": run_name,
  3337. "includeUpload": include_upload,
  3338. "type": type,
  3339. },
  3340. )
  3341. return (
  3342. response["createArtifactManifest"]["artifactManifest"]["id"],
  3343. response["createArtifactManifest"]["artifactManifest"]["file"],
  3344. )
  3345. def update_artifact_manifest(
  3346. self,
  3347. artifact_manifest_id: str,
  3348. base_artifact_id: str | None = None,
  3349. digest: str | None = None,
  3350. include_upload: bool | None = True,
  3351. ) -> tuple[str, dict[str, Any]]:
  3352. mutation = gql(
  3353. """
  3354. mutation UpdateArtifactManifest(
  3355. $artifactManifestID: ID!,
  3356. $digest: String,
  3357. $baseArtifactID: ID,
  3358. $includeUpload: Boolean!,
  3359. ) {
  3360. updateArtifactManifest(input: {
  3361. artifactManifestID: $artifactManifestID,
  3362. digest: $digest,
  3363. baseArtifactID: $baseArtifactID,
  3364. }) {
  3365. artifactManifest {
  3366. id
  3367. file {
  3368. id
  3369. name
  3370. displayName
  3371. uploadUrl @include(if: $includeUpload)
  3372. uploadHeaders @include(if: $includeUpload)
  3373. }
  3374. }
  3375. }
  3376. }
  3377. """
  3378. )
  3379. response = self.gql(
  3380. mutation,
  3381. variable_values={
  3382. "artifactManifestID": artifact_manifest_id,
  3383. "digest": digest,
  3384. "baseArtifactID": base_artifact_id,
  3385. "includeUpload": include_upload,
  3386. },
  3387. )
  3388. return (
  3389. response["updateArtifactManifest"]["artifactManifest"]["id"],
  3390. response["updateArtifactManifest"]["artifactManifest"]["file"],
  3391. )
  3392. def update_artifact_metadata(
  3393. self, artifact_id: str, metadata: dict[str, Any]
  3394. ) -> dict[str, Any]:
  3395. """Set the metadata of the given artifact version."""
  3396. mutation = gql(
  3397. """
  3398. mutation UpdateArtifact(
  3399. $artifactID: ID!,
  3400. $metadata: JSONString,
  3401. ) {
  3402. updateArtifact(input: {
  3403. artifactID: $artifactID,
  3404. metadata: $metadata,
  3405. }) {
  3406. artifact {
  3407. id
  3408. }
  3409. }
  3410. }
  3411. """
  3412. )
  3413. response = self.gql(
  3414. mutation,
  3415. variable_values={
  3416. "artifactID": artifact_id,
  3417. "metadata": json.dumps(metadata),
  3418. },
  3419. )
  3420. return response["updateArtifact"]["artifact"]
  3421. def _resolve_client_id(
  3422. self,
  3423. client_id: str,
  3424. ) -> str | None:
  3425. if client_id in self._client_id_mapping:
  3426. return self._client_id_mapping[client_id]
  3427. query = gql(
  3428. """
  3429. query ClientIDMapping($clientID: ID!) {
  3430. clientIDMapping(clientID: $clientID) {
  3431. serverID
  3432. }
  3433. }
  3434. """
  3435. )
  3436. response = self.gql(
  3437. query,
  3438. variable_values={
  3439. "clientID": client_id,
  3440. },
  3441. )
  3442. server_id = None
  3443. if response is not None:
  3444. client_id_mapping = response.get("clientIDMapping")
  3445. if client_id_mapping is not None:
  3446. server_id = client_id_mapping.get("serverID")
  3447. if server_id is not None:
  3448. self._client_id_mapping[client_id] = server_id
  3449. return server_id
  3450. @normalize_exceptions
  3451. def create_artifact_files(
  3452. self, artifact_files: Iterable[CreateArtifactFileSpecInput]
  3453. ) -> Mapping[str, CreateArtifactFilesResponseFile]:
  3454. query_template = """
  3455. mutation CreateArtifactFiles(
  3456. $storageLayout: ArtifactStorageLayout!
  3457. $artifactFiles: [CreateArtifactFileSpecInput!]!
  3458. ) {
  3459. createArtifactFiles(input: {
  3460. artifactFiles: $artifactFiles,
  3461. storageLayout: $storageLayout,
  3462. }) {
  3463. files {
  3464. edges {
  3465. node {
  3466. id
  3467. name
  3468. displayName
  3469. uploadUrl
  3470. uploadHeaders
  3471. storagePath
  3472. uploadMultipartUrls {
  3473. uploadID
  3474. uploadUrlParts {
  3475. partNumber
  3476. uploadUrl
  3477. }
  3478. }
  3479. artifact {
  3480. id
  3481. }
  3482. }
  3483. }
  3484. }
  3485. }
  3486. }
  3487. """
  3488. # TODO: we should use constants here from interface/artifacts.py
  3489. # but probably don't want the dependency. We're going to remove
  3490. # this setting in a future release, so I'm just hard-coding the strings.
  3491. storage_layout = "V2"
  3492. if env.get_use_v1_artifacts():
  3493. storage_layout = "V1"
  3494. mutation = gql(query_template)
  3495. response = self.gql(
  3496. mutation,
  3497. variable_values={
  3498. "storageLayout": storage_layout,
  3499. "artifactFiles": [af for af in artifact_files],
  3500. },
  3501. )
  3502. result = {}
  3503. for edge in response["createArtifactFiles"]["files"]["edges"]:
  3504. node = edge["node"]
  3505. result[node["displayName"]] = node
  3506. return result
  3507. @normalize_exceptions
  3508. def notify_scriptable_run_alert(
  3509. self,
  3510. title: str,
  3511. text: str,
  3512. level: str | None = None,
  3513. wait_duration: Number | None = None,
  3514. ) -> bool:
  3515. mutation = gql(
  3516. """
  3517. mutation NotifyScriptableRunAlert(
  3518. $entityName: String!,
  3519. $projectName: String!,
  3520. $runName: String!,
  3521. $title: String!,
  3522. $text: String!,
  3523. $severity: AlertSeverity = INFO,
  3524. $waitDuration: Duration
  3525. ) {
  3526. notifyScriptableRunAlert(input: {
  3527. entityName: $entityName,
  3528. projectName: $projectName,
  3529. runName: $runName,
  3530. title: $title,
  3531. text: $text,
  3532. severity: $severity,
  3533. waitDuration: $waitDuration
  3534. }) {
  3535. success
  3536. }
  3537. }
  3538. """
  3539. )
  3540. response = self.gql(
  3541. mutation,
  3542. variable_values={
  3543. "entityName": self.settings("entity"),
  3544. "projectName": self.settings("project"),
  3545. "runName": self.current_run_id,
  3546. "title": title,
  3547. "text": text,
  3548. "severity": level,
  3549. "waitDuration": wait_duration,
  3550. },
  3551. )
  3552. success: bool = response["notifyScriptableRunAlert"]["success"]
  3553. return success
  3554. def get_sweep_state(
  3555. self, sweep: str, entity: str | None = None, project: str | None = None
  3556. ) -> SweepState:
  3557. query = gql(
  3558. """
  3559. query GetSweepState($entity: String, $project: String, $sweep: String!) {
  3560. project(name: $project, entityName: $entity) {
  3561. sweep(sweepName: $sweep) {
  3562. state
  3563. }
  3564. }
  3565. }
  3566. """
  3567. )
  3568. response = self.gql(
  3569. query,
  3570. variable_values={
  3571. "sweep": sweep,
  3572. "entity": entity or self.settings("entity"),
  3573. "project": project or self.settings("project"),
  3574. },
  3575. )
  3576. return response["project"]["sweep"]["state"]
  3577. def set_sweep_state(
  3578. self,
  3579. sweep: str,
  3580. state: SweepState,
  3581. entity: str | None = None,
  3582. project: str | None = None,
  3583. ) -> None:
  3584. assert state in ("RUNNING", "PAUSED", "CANCELED", "FINISHED")
  3585. s = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
  3586. curr_state = s["state"].upper()
  3587. if state == "PAUSED" and curr_state not in ("PAUSED", "RUNNING"):
  3588. raise Exception(f"Cannot pause {curr_state.lower()} sweep.")
  3589. elif state != "RUNNING" and curr_state not in ("RUNNING", "PAUSED", "PENDING"):
  3590. raise Exception(f"Sweep already {curr_state.lower()}.")
  3591. sweep_id = s["id"]
  3592. mutation = gql(
  3593. """
  3594. mutation UpsertSweep(
  3595. $id: ID,
  3596. $state: String,
  3597. $entityName: String,
  3598. $projectName: String
  3599. ) {
  3600. upsertSweep(input: {
  3601. id: $id,
  3602. state: $state,
  3603. entityName: $entityName,
  3604. projectName: $projectName
  3605. }){
  3606. sweep {
  3607. name
  3608. }
  3609. }
  3610. }
  3611. """
  3612. )
  3613. self.gql(
  3614. mutation,
  3615. variable_values={
  3616. "id": sweep_id,
  3617. "state": state,
  3618. "entityName": entity or self.settings("entity"),
  3619. "projectName": project or self.settings("project"),
  3620. },
  3621. )
  3622. def stop_sweep(
  3623. self,
  3624. sweep: str,
  3625. entity: str | None = None,
  3626. project: str | None = None,
  3627. ) -> None:
  3628. """Finish the sweep to stop running new runs and let currently running runs finish."""
  3629. self.set_sweep_state(
  3630. sweep=sweep, state="FINISHED", entity=entity, project=project
  3631. )
  3632. def cancel_sweep(
  3633. self,
  3634. sweep: str,
  3635. entity: str | None = None,
  3636. project: str | None = None,
  3637. ) -> None:
  3638. """Cancel the sweep to kill all running runs and stop running new runs."""
  3639. self.set_sweep_state(
  3640. sweep=sweep, state="CANCELED", entity=entity, project=project
  3641. )
  3642. def pause_sweep(
  3643. self,
  3644. sweep: str,
  3645. entity: str | None = None,
  3646. project: str | None = None,
  3647. ) -> None:
  3648. """Pause the sweep to temporarily stop running new runs."""
  3649. self.set_sweep_state(
  3650. sweep=sweep, state="PAUSED", entity=entity, project=project
  3651. )
  3652. def resume_sweep(
  3653. self,
  3654. sweep: str,
  3655. entity: str | None = None,
  3656. project: str | None = None,
  3657. ) -> None:
  3658. """Resume the sweep to continue running new runs."""
  3659. self.set_sweep_state(
  3660. sweep=sweep, state="RUNNING", entity=entity, project=project
  3661. )
  3662. def _status_request(self, url: str, length: int) -> requests.Response:
  3663. """Ask google how much we've uploaded."""
  3664. import requests
  3665. check_httpclient_logger_handler()
  3666. return requests.put(
  3667. url=url,
  3668. headers={"Content-Length": "0", "Content-Range": f"bytes */{length}"},
  3669. )
  3670. def _flatten_edges(self, response: _Response) -> list[dict]:
  3671. """Return an array from the nested graphql relay structure."""
  3672. return [node["node"] for node in response["edges"]]
  3673. @normalize_exceptions
  3674. def stop_run(
  3675. self,
  3676. run_id: str,
  3677. ) -> bool:
  3678. mutation = gql(
  3679. """
  3680. mutation stopRun($id: ID!) {
  3681. stopRun(input: {
  3682. id: $id
  3683. }) {
  3684. clientMutationId
  3685. success
  3686. }
  3687. }
  3688. """
  3689. )
  3690. response = self.gql(
  3691. mutation,
  3692. variable_values={
  3693. "id": run_id,
  3694. },
  3695. )
  3696. success: bool = response["stopRun"].get("success")
  3697. return success
  3698. @normalize_exceptions
  3699. def create_custom_chart(
  3700. self,
  3701. entity: str,
  3702. name: str,
  3703. display_name: str,
  3704. spec_type: str,
  3705. access: str,
  3706. spec: str | Mapping[str, Any],
  3707. ) -> dict[str, Any] | None:
  3708. if not isinstance(spec, str):
  3709. spec = json.dumps(spec)
  3710. mutation = gql(
  3711. """
  3712. mutation CreateCustomChart(
  3713. $entity: String!
  3714. $name: String!
  3715. $displayName: String!
  3716. $type: String!
  3717. $access: String!
  3718. $spec: JSONString!
  3719. ) {
  3720. createCustomChart(
  3721. input: {
  3722. entity: $entity
  3723. name: $name
  3724. displayName: $displayName
  3725. type: $type
  3726. access: $access
  3727. spec: $spec
  3728. }
  3729. ) {
  3730. chart { id }
  3731. }
  3732. }
  3733. """
  3734. )
  3735. variable_values = {
  3736. "entity": entity,
  3737. "name": name,
  3738. "displayName": display_name,
  3739. "type": spec_type,
  3740. "access": access,
  3741. "spec": spec,
  3742. }
  3743. result: dict[str, Any] | None = self.gql(mutation, variable_values)[
  3744. "createCustomChart"
  3745. ]
  3746. return result