util.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
import html
import importlib
import importlib.util
import logging
import sys
import textwrap
from functools import wraps
from typing import Any, Callable, Iterable, Optional, TypeVar, Union

from packaging.version import Version

from ray._private.thirdparty.tabulate.tabulate import tabulate
from ray.util.annotations import DeveloperAPI
from ray.widgets import Template
  11. logger = logging.getLogger(__name__)
  12. F = TypeVar("F", bound=Callable[..., Any])
  13. @DeveloperAPI
  14. def make_table_html_repr(
  15. obj: Any, title: Optional[str] = None, max_height: str = "none"
  16. ) -> str:
  17. """Generate a generic html repr using a table.
  18. Args:
  19. obj: Object for which a repr is to be generated
  20. title: If present, a title for the section is included
  21. max_height: Maximum height of the table; valid values
  22. are given by the max-height CSS property
  23. Returns:
  24. HTML representation of the object
  25. """
  26. data = {}
  27. for k, v in vars(obj).items():
  28. if isinstance(v, (str, bool, int, float)):
  29. data[k] = str(v)
  30. elif isinstance(v, dict) or hasattr(v, "__dict__"):
  31. data[k] = Template("scrollableTable.html.j2").render(
  32. table=tabulate(
  33. v.items() if isinstance(v, dict) else vars(v).items(),
  34. tablefmt="html",
  35. showindex=False,
  36. headers=["Setting", "Value"],
  37. ),
  38. max_height="none",
  39. )
  40. table = Template("scrollableTable.html.j2").render(
  41. table=tabulate(
  42. data.items(),
  43. tablefmt="unsafehtml",
  44. showindex=False,
  45. headers=["Setting", "Value"],
  46. ),
  47. max_height=max_height,
  48. )
  49. if title:
  50. content = Template("title_data.html.j2").render(title=title, data=table)
  51. else:
  52. content = table
  53. return content
  54. def _has_missing(
  55. *deps: Iterable[Union[str, Optional[str]]], message: Optional[str] = None
  56. ):
  57. """Return a list of missing dependencies.
  58. Args:
  59. deps: Dependencies to check for
  60. message: Message to be emitted if a dependency isn't found
  61. Returns:
  62. A list of dependencies which can't be found, if any
  63. """
  64. missing = []
  65. for (lib, _) in deps:
  66. if importlib.util.find_spec(lib) is None:
  67. missing.append(lib)
  68. if missing:
  69. if not message:
  70. message = f"Run `pip install {' '.join(missing)}` for rich notebook output."
  71. # stacklevel=3: First level is this function, then ensure_notebook_deps,
  72. # then the actual function affected.
  73. logger.info(f"Missing packages: {missing}. {message}", stacklevel=3)
  74. return missing
  75. def _has_outdated(
  76. *deps: Iterable[Union[str, Optional[str]]], message: Optional[str] = None
  77. ):
  78. outdated = []
  79. for (lib, version) in deps:
  80. try:
  81. module = importlib.import_module(lib)
  82. if version and Version(module.__version__) < Version(version):
  83. outdated.append([lib, version, module.__version__])
  84. except ImportError:
  85. pass
  86. if outdated:
  87. outdated_strs = []
  88. install_args = []
  89. for lib, version, installed in outdated:
  90. outdated_strs.append(f"{lib}=={installed} found, needs {lib}>={version}")
  91. install_args.append(f"{lib}>={version}")
  92. outdated_str = textwrap.indent("\n".join(outdated_strs), " ")
  93. install_str = " ".join(install_args)
  94. if not message:
  95. message = f"Run `pip install -U {install_str}` for rich notebook output."
  96. # stacklevel=3: First level is this function, then ensure_notebook_deps,
  97. # then the actual function affected.
  98. logger.info(f"Outdated packages:\n{outdated_str}\n{message}", stacklevel=3)
  99. return outdated
  100. @DeveloperAPI
  101. def repr_with_fallback(
  102. *notebook_deps: Iterable[Union[str, Optional[str]]]
  103. ) -> Callable[[F], F]:
  104. """Decorator which strips rich notebook output from mimebundles in certain cases.
  105. Fallback to plaintext and don't use rich output in the following cases:
  106. 1. In a notebook environment and the appropriate dependencies are not installed.
  107. 2. In a ipython shell environment.
  108. 3. In Google Colab environment.
  109. See https://github.com/googlecolab/colabtools/ issues/60 for more information
  110. about the status of this issue.
  111. Args:
  112. notebook_deps: The required dependencies and version for notebook environment.
  113. Returns:
  114. A function that returns the usual _repr_mimebundle_, unless any of the 3
  115. conditions above hold, in which case it returns a mimebundle that only contains
  116. a single text/plain mimetype.
  117. """
  118. message = (
  119. "Run `pip install -U ipywidgets`, then restart "
  120. "the notebook server for rich notebook output."
  121. )
  122. if _can_display_ipywidgets(*notebook_deps, message=message):
  123. def wrapper(func: F) -> F:
  124. @wraps(func)
  125. def wrapped(self, *args, **kwargs):
  126. return func(self, *args, **kwargs)
  127. return wrapped
  128. else:
  129. def wrapper(func: F) -> F:
  130. @wraps(func)
  131. def wrapped(self, *args, **kwargs):
  132. return {"text/plain": repr(self)}
  133. return wrapped
  134. return wrapper
  135. def _get_ipython_shell_name() -> str:
  136. if "IPython" in sys.modules:
  137. from IPython import get_ipython
  138. return get_ipython().__class__.__name__
  139. return ""
  140. def _can_display_ipywidgets(*deps, message) -> bool:
  141. # Default to safe behavior: only display widgets if running in a notebook
  142. # that has valid dependencies
  143. if in_notebook() and not (
  144. _has_missing(*deps, message=message) or _has_outdated(*deps, message=message)
  145. ):
  146. return True
  147. return False
  148. @DeveloperAPI
  149. def in_notebook(shell_name: Optional[str] = None) -> bool:
  150. """Return whether we are in a Jupyter notebook or qtconsole."""
  151. if not shell_name:
  152. shell_name = _get_ipython_shell_name()
  153. return shell_name == "ZMQInteractiveShell"
  154. @DeveloperAPI
  155. def in_ipython_shell(shell_name: Optional[str] = None) -> bool:
  156. """Return whether we are in a terminal running IPython"""
  157. if not shell_name:
  158. shell_name = _get_ipython_shell_name()
  159. return shell_name == "TerminalInteractiveShell"