jupyter.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. from __future__ import annotations
  2. import json
  3. import logging
  4. import os
  5. import re
  6. import shutil
  7. import sys
  8. import traceback
  9. from base64 import b64encode
  10. from typing import Any
  11. import IPython
  12. import IPython.display
  13. import requests
  14. from IPython.core.magic import Magics, line_cell_magic, magics_class
  15. from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
  16. from requests.compat import urljoin
  17. import wandb
  18. import wandb.util
  19. from wandb.sdk import wandb_setup
  20. from wandb.sdk.lib import filesystem
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  22. def display_if_magic_is_used(run: wandb.Run) -> bool:
  23. """Display a run's page if the cell has the %%wandb cell magic.
  24. Args:
  25. run: The run to display.
  26. Returns:
  27. Whether the %%wandb cell magic was present.
  28. """
  29. if not _current_cell_wandb_magic:
  30. return False
  31. _current_cell_wandb_magic.display_if_allowed(run)
  32. return True
  33. class _WandbCellMagicState:
  34. """State for a cell with the %%wandb cell magic."""
  35. def __init__(self, *, height: int) -> None:
  36. """Initializes the %%wandb cell magic state.
  37. Args:
  38. height: The desired height for displayed iframes.
  39. """
  40. self._height = height
  41. self._already_displayed = False
  42. def display_if_allowed(self, run: wandb.Run) -> None:
  43. """Display a run's iframe if one is not already displayed.
  44. Args:
  45. run: The run to display.
  46. """
  47. if self._already_displayed:
  48. return
  49. self._already_displayed = True
  50. _display_wandb_run(run, height=self._height)
# State for the %%wandb cell currently executing, or None when there is none.
# WandBMagics.wandb sets it while running a %%wandb cell in which nothing was
# displayed yet and resets it to None afterwards; display_if_magic_is_used
# reads it.
_current_cell_wandb_magic: _WandbCellMagicState | None = None
  52. def _display_by_wandb_path(path: str, *, height: int) -> None:
  53. """Display a wandb object (usually in an iframe) given its URI.
  54. Args:
  55. path: A path to a run, sweep, project, report, etc.
  56. height: Height of the iframe in pixels.
  57. """
  58. api = wandb.Api()
  59. try:
  60. obj = api.from_path(path)
  61. IPython.display.display_html(
  62. obj.to_html(height=height),
  63. raw=True,
  64. )
  65. except wandb.Error:
  66. traceback.print_exc()
  67. IPython.display.display_html(
  68. f"Path {path!r} does not refer to a W&B object you can access.",
  69. raw=True,
  70. )
  71. def _display_wandb_run(run: wandb.Run, *, height: int) -> None:
  72. """Display a run (usually in an iframe).
  73. Args:
  74. run: The run to display.
  75. height: Height of the iframe in pixels.
  76. """
  77. IPython.display.display_html(
  78. run.to_html(height=height),
  79. raw=True,
  80. )
  81. @magics_class
  82. class WandBMagics(Magics):
  83. def __init__(self, shell):
  84. super().__init__(shell)
  85. @magic_arguments()
  86. @argument(
  87. "path",
  88. default=None,
  89. nargs="?",
  90. help="The path to a resource you want to display.",
  91. )
  92. @argument(
  93. "-h",
  94. "--height",
  95. default=420,
  96. type=int,
  97. help="The height of the iframe in pixels.",
  98. )
  99. @line_cell_magic
  100. def wandb(self, line: str, cell: str | None = None) -> None:
  101. """Display wandb resources in Jupyter.
  102. This can be used as a line magic:
  103. %wandb USERNAME/PROJECT/runs/RUN_ID
  104. Or as a cell magic:
  105. %%wandb -h 1024
  106. with wandb.init() as run:
  107. run.log({"loss": 1})
  108. """
  109. global _current_cell_wandb_magic
  110. args = parse_argstring(self.wandb, line)
  111. path: str | None = args.path
  112. height: int = args.height
  113. if path:
  114. _display_by_wandb_path(path, height=height)
  115. displayed = True
  116. elif run := wandb_setup.singleton().most_recent_active_run:
  117. _display_wandb_run(run, height=height)
  118. displayed = True
  119. else:
  120. displayed = False
  121. # If this is being used as a line magic ("%wandb"), we are done.
  122. # When used as a cell magic ("%%wandb"), we must run the cell.
  123. if cell is None:
  124. return
  125. if not displayed:
  126. _current_cell_wandb_magic = _WandbCellMagicState(height=height)
  127. try:
  128. IPython.get_ipython().run_cell(cell)
  129. finally:
  130. _current_cell_wandb_magic = None
  131. def notebook_metadata_from_jupyter_servers_and_kernel_id():
  132. # When running in VS Code's notebook extension,
  133. # the extension creates a temporary file to start the kernel.
  134. # This file is not actually the same as the notebook file.
  135. #
  136. # The real notebook path is stored in the user namespace
  137. # under the key "__vsc_ipynb_file__"
  138. try:
  139. from IPython import get_ipython
  140. ipython = get_ipython()
  141. if ipython is not None:
  142. notebook_path = ipython.kernel.shell.user_ns.get("__vsc_ipynb_file__")
  143. if notebook_path:
  144. return {
  145. "root": os.path.dirname(notebook_path),
  146. "path": notebook_path,
  147. "name": os.path.basename(notebook_path),
  148. }
  149. except ModuleNotFoundError:
  150. return None
  151. servers, kernel_id = jupyter_servers_and_kernel_id()
  152. for s in servers:
  153. if s.get("password"):
  154. raise ValueError("Can't query password protected kernel")
  155. res = requests.get(
  156. urljoin(s["url"], "api/sessions"), params={"token": s.get("token", "")}
  157. ).json()
  158. for nn in res:
  159. if (
  160. isinstance(nn, dict)
  161. and nn.get("kernel")
  162. and "notebook" in nn
  163. and nn["kernel"]["id"] == kernel_id
  164. ):
  165. return {
  166. "root": s.get("root_dir", s.get("notebook_dir", os.getcwd())),
  167. "path": nn["notebook"]["path"],
  168. "name": nn["notebook"]["name"],
  169. }
  170. if not kernel_id:
  171. return None
  172. def notebook_metadata(silent: bool) -> dict[str, str]:
  173. """Attempt to query jupyter for the path and name of the notebook file.
  174. This can handle different jupyter environments, specifically:
  175. 1. Colab
  176. 2. Kaggle
  177. 3. JupyterLab
  178. 4. Notebooks
  179. 5. Other?
  180. """
  181. error_message = (
  182. "Failed to detect the name of this notebook. You can set it manually"
  183. " with the WANDB_NOTEBOOK_NAME environment variable to enable code"
  184. " saving."
  185. )
  186. try:
  187. jupyter_metadata = notebook_metadata_from_jupyter_servers_and_kernel_id()
  188. # Colab:
  189. # request the most recent contents
  190. ipynb = attempt_colab_load_ipynb()
  191. if ipynb is not None and jupyter_metadata is not None:
  192. return {
  193. "root": "/content",
  194. "path": jupyter_metadata["path"],
  195. "name": jupyter_metadata["name"],
  196. }
  197. # Kaggle:
  198. if wandb.util._is_kaggle():
  199. # request the most recent contents
  200. ipynb = attempt_kaggle_load_ipynb()
  201. if ipynb:
  202. return {
  203. "root": "/kaggle/working",
  204. "path": ipynb["metadata"]["name"],
  205. "name": ipynb["metadata"]["name"],
  206. }
  207. if jupyter_metadata:
  208. return jupyter_metadata
  209. except Exception:
  210. logger.exception(error_message)
  211. wandb.termerror(error_message)
  212. return {}
  213. def jupyter_servers_and_kernel_id():
  214. """Return a list of servers and the current kernel_id.
  215. Used to query for the name of the notebook.
  216. """
  217. try:
  218. import ipykernel # type: ignore
  219. kernel_id = re.search(
  220. "kernel-(.*).json", ipykernel.connect.get_connection_file()
  221. ).group(1)
  222. # We're either in jupyterlab or a notebook, lets prefer the newer jupyter_server package
  223. serverapp = wandb.util.get_module("jupyter_server.serverapp")
  224. notebookapp = wandb.util.get_module("notebook.notebookapp")
  225. servers = []
  226. if serverapp is not None:
  227. servers.extend(list(serverapp.list_running_servers()))
  228. if notebookapp is not None:
  229. servers.extend(list(notebookapp.list_running_servers()))
  230. except (AttributeError, ValueError, ImportError):
  231. return [], None
  232. return servers, kernel_id
  233. def attempt_colab_load_ipynb():
  234. colab = wandb.util.get_module("google.colab")
  235. if colab:
  236. # This isn't thread safe, never call in a thread
  237. response = colab._message.blocking_request("get_ipynb", timeout_sec=5)
  238. if response:
  239. return response["ipynb"]
  240. def attempt_kaggle_load_ipynb():
  241. kaggle = wandb.util.get_module("kaggle_session")
  242. if not kaggle:
  243. return None
  244. try:
  245. client = kaggle.UserSessionClient()
  246. parsed = json.loads(client.get_exportable_ipynb()["source"])
  247. # TODO: couldn't find a way to get the name of the notebook...
  248. parsed["metadata"]["name"] = "kaggle.ipynb"
  249. except Exception:
  250. wandb.termerror("Unable to load kaggle notebook.")
  251. logger.exception("Unable to load kaggle notebook.")
  252. return None
  253. return parsed
  254. class Notebook:
  255. def __init__(self, settings: wandb.Settings) -> None:
  256. self.outputs: dict[int, Any] = {}
  257. self.settings = settings
  258. self.shell = IPython.get_ipython()
  259. def save_display(self, exc_count, data_with_metadata):
  260. self.outputs[exc_count] = self.outputs.get(exc_count, [])
  261. # byte values such as images need to be encoded in base64
  262. # otherwise nbformat.v4.new_output will throw a NotebookValidationError
  263. data = data_with_metadata["data"]
  264. b64_data = {}
  265. for key in data:
  266. val = data[key]
  267. if isinstance(val, bytes):
  268. b64_data[key] = b64encode(val).decode("utf-8")
  269. else:
  270. b64_data[key] = val
  271. self.outputs[exc_count].append(
  272. {"data": b64_data, "metadata": data_with_metadata["metadata"]}
  273. )
  274. def probe_ipynb(self):
  275. """Return notebook as dict or None."""
  276. relpath = self.settings.x_jupyter_path
  277. if relpath and os.path.exists(relpath):
  278. with open(relpath) as json_file:
  279. data = json.load(json_file)
  280. return data
  281. colab_ipynb = attempt_colab_load_ipynb()
  282. if colab_ipynb:
  283. return colab_ipynb
  284. kaggle_ipynb = attempt_kaggle_load_ipynb()
  285. if kaggle_ipynb and len(kaggle_ipynb["cells"]) > 0:
  286. return kaggle_ipynb
  287. return
  288. def save_ipynb(self) -> bool:
  289. if not self.settings.save_code:
  290. logger.info("not saving jupyter notebook")
  291. return False
  292. ret = False
  293. try:
  294. ret = self._save_ipynb()
  295. except Exception:
  296. wandb.termerror("Failed to save notebook.")
  297. logger.exception("Problem saving notebook.")
  298. return ret
  299. def _save_ipynb(self) -> bool:
  300. relpath = self.settings.x_jupyter_path
  301. logger.info("looking for notebook: %s", relpath)
  302. if relpath and os.path.exists(relpath):
  303. shutil.copy(
  304. relpath,
  305. os.path.join(self.settings._tmp_code_dir, os.path.basename(relpath)),
  306. )
  307. return True
  308. # TODO: likely only save if the code has changed
  309. colab_ipynb = attempt_colab_load_ipynb()
  310. if colab_ipynb:
  311. try:
  312. jupyter_metadata = (
  313. notebook_metadata_from_jupyter_servers_and_kernel_id()
  314. )
  315. nb_name = jupyter_metadata["name"]
  316. except Exception:
  317. nb_name = "colab.ipynb"
  318. if not nb_name.endswith(".ipynb"):
  319. nb_name += ".ipynb"
  320. with open(
  321. os.path.join(
  322. self.settings._tmp_code_dir,
  323. nb_name,
  324. ),
  325. "w",
  326. encoding="utf-8",
  327. ) as f:
  328. f.write(json.dumps(colab_ipynb))
  329. return True
  330. kaggle_ipynb = attempt_kaggle_load_ipynb()
  331. if kaggle_ipynb and len(kaggle_ipynb["cells"]) > 0:
  332. with open(
  333. os.path.join(
  334. self.settings._tmp_code_dir, kaggle_ipynb["metadata"]["name"]
  335. ),
  336. "w",
  337. encoding="utf-8",
  338. ) as f:
  339. f.write(json.dumps(kaggle_ipynb))
  340. return True
  341. return False
  342. def save_history(self, run: wandb.Run):
  343. """This saves all cell executions in the current session as a new notebook."""
  344. try:
  345. from nbformat import v4, validator, write # type: ignore
  346. except ImportError:
  347. wandb.termerror(
  348. "The nbformat package was not found."
  349. " It is required to save notebook history."
  350. )
  351. return
  352. # TODO: some tests didn't patch ipython properly?
  353. if self.shell is None:
  354. return
  355. cells = []
  356. hist = list(self.shell.history_manager.get_range(output=True))
  357. if len(hist) <= 1 or not self.settings.save_code:
  358. logger.info("not saving jupyter history")
  359. return
  360. try:
  361. for _, execution_count, exc in hist:
  362. if exc[1]:
  363. # TODO: capture stderr?
  364. outputs = [
  365. v4.new_output(output_type="stream", name="stdout", text=exc[1])
  366. ]
  367. else:
  368. outputs = []
  369. if self.outputs.get(execution_count):
  370. for out in self.outputs[execution_count]:
  371. outputs.append(
  372. v4.new_output(
  373. output_type="display_data",
  374. data=out["data"],
  375. metadata=out["metadata"] or {},
  376. )
  377. )
  378. cells.append(
  379. v4.new_code_cell(
  380. execution_count=execution_count, source=exc[0], outputs=outputs
  381. )
  382. )
  383. if hasattr(self.shell, "kernel"):
  384. language_info = self.shell.kernel.language_info
  385. else:
  386. language_info = {"name": "python", "version": sys.version}
  387. logger.info("saving %i cells to _session_history.ipynb", len(cells))
  388. nb = v4.new_notebook(
  389. cells=cells,
  390. metadata={
  391. "kernelspec": {
  392. "display_name": f"Python {sys.version_info[0]}",
  393. "name": f"python{sys.version_info[0]}",
  394. "language": "python",
  395. },
  396. "language_info": language_info,
  397. },
  398. )
  399. state_path = os.path.join("code", "_session_history.ipynb")
  400. run._set_config_wandb("session_history", state_path)
  401. filesystem.mkdir_exists_ok(os.path.join(self.settings.files_dir, "code"))
  402. with open(
  403. os.path.join(self.settings._tmp_code_dir, "_session_history.ipynb"),
  404. "w",
  405. encoding="utf-8",
  406. ) as f:
  407. write(nb, f, version=4)
  408. with open(
  409. os.path.join(self.settings.files_dir, state_path),
  410. "w",
  411. encoding="utf-8",
  412. ) as f:
  413. write(nb, f, version=4)
  414. except (OSError, validator.NotebookValidationError):
  415. wandb.termerror("Unable to save notebook session history.")
  416. logger.exception("Unable to save notebook session history.")
  417. def _load_ipython_extension(ipython):
  418. """Best-effort auto-registration of W&B magics in notebook contexts."""
  419. if ipython is None:
  420. return
  421. try:
  422. ipython.register_magics(WandBMagics)
  423. except Exception:
  424. logger.debug("Failed to register IPython magics.", exc_info=True)
  425. return