ipython.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from __future__ import annotations
  2. import logging
  3. import sys
  4. import warnings
  5. from typing import Literal
  6. import wandb
  7. PythonType = Literal["python", "ipython", "jupyter"]
  8. logger = logging.getLogger(__name__)
  9. def toggle_button(what="run"):
  10. """Returns the HTML for a button used to reveal the element following it.
  11. The element immediately after the button must have `display: none`.
  12. """
  13. return (
  14. "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">"
  15. f"Display W&B {what}"
  16. "</button>"
  17. )
  18. def _get_python_type() -> PythonType:
  19. if "IPython" not in sys.modules:
  20. return "python"
  21. try:
  22. from IPython import get_ipython # type: ignore
  23. # Calling get_ipython can cause an ImportError
  24. if get_ipython() is None:
  25. return "python"
  26. except ImportError:
  27. return "python"
  28. # jupyter-based environments (e.g. jupyter itself, colab, kaggle, etc) have a connection file
  29. ip_kernel_app_connection_file = (
  30. (get_ipython().config.get("IPKernelApp", {}) or {})
  31. .get("connection_file", "")
  32. .lower()
  33. ) or (
  34. (get_ipython().config.get("ColabKernelApp", {}) or {})
  35. .get("connection_file", "")
  36. .lower()
  37. )
  38. if (
  39. ("terminal" in get_ipython().__module__)
  40. or ("jupyter" not in ip_kernel_app_connection_file)
  41. or ("spyder" in sys.modules)
  42. ):
  43. return "ipython"
  44. else:
  45. return "jupyter"
  46. def in_jupyter() -> bool:
  47. """Returns True if we're in a Jupyter notebook."""
  48. return _get_python_type() == "jupyter"
  49. def in_ipython() -> bool:
  50. """Returns True if we're running in IPython in the terminal."""
  51. return _get_python_type() == "ipython"
  52. def in_notebook() -> bool:
  53. """Returns True if we're running in Jupyter or IPython."""
  54. return _get_python_type() != "python"
  55. def in_vscode_notebook() -> bool:
  56. """Returns True if we're in a VSCode notebook."""
  57. try:
  58. from IPython import get_ipython
  59. except ModuleNotFoundError:
  60. return False
  61. ipython = get_ipython()
  62. if not ipython:
  63. return False
  64. return ipython.kernel.shell.user_ns.get("__vsc_ipynb_file__") is not None
  65. class ProgressWidget:
  66. """A simple wrapper to render a nice progress bar with a label."""
  67. def __init__(self, widgets, min, max):
  68. from IPython import display
  69. self._ipython_display = display
  70. self.widgets = widgets
  71. self._progress = widgets.FloatProgress(min=min, max=max)
  72. self._label = widgets.Label()
  73. self._widget = self.widgets.VBox([self._label, self._progress])
  74. self._displayed = False
  75. self._disabled = False
  76. def update(self, value: float, label: str) -> None:
  77. if self._disabled:
  78. return
  79. try:
  80. self._progress.value = value
  81. self._label.value = label
  82. if not self._displayed:
  83. self._displayed = True
  84. self._ipython_display.display(self._widget)
  85. except Exception:
  86. logger.exception("Error in ProgressWidget.update()")
  87. self._disabled = True
  88. wandb.termwarn(
  89. "Unable to render progress bar, see the user log for details"
  90. )
  91. def close(self) -> None:
  92. if self._disabled or not self._displayed:
  93. return
  94. self._widget.close()
  95. def jupyter_progress_bar(min: float = 0, max: float = 1.0) -> ProgressWidget | None:
  96. """Return an ipywidget progress bar or None if we can't import it."""
  97. widgets = wandb.util.get_module("ipywidgets")
  98. try:
  99. if widgets is None:
  100. # TODO: this currently works in iPython but it's deprecated since 4.0
  101. with warnings.catch_warnings():
  102. warnings.simplefilter("ignore")
  103. from IPython.html import widgets # type: ignore
  104. assert hasattr(widgets, "VBox")
  105. assert hasattr(widgets, "Label")
  106. assert hasattr(widgets, "FloatProgress")
  107. return ProgressWidget(widgets, min=min, max=max)
  108. except (ImportError, AssertionError):
  109. return None