__init__.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. """Use wandb to track machine learning work.
  2. Train and fine-tune models, manage models from experimentation to production.
  3. For guides and examples, see https://docs.wandb.ai.
  4. For scripts and interactive notebooks, see https://github.com/wandb/examples.
  5. For reference documentation, see https://docs.wandb.ai/models/ref/python.
  6. """
  7. from __future__ import annotations
  8. __version__ = "0.26.0"
  9. from wandb.errors import Error
  10. # This needs to be early as other modules call it.
  11. from wandb.errors.term import termsetup, termlog, termerror, termwarn
  12. # Configure the logger as early as possible for consistent behavior.
  13. from wandb.sdk.lib import wb_logging as _wb_logging
  14. _wb_logging.configure_wandb_logger()
  15. from wandb import sdk as wandb_sdk
  16. import wandb
  17. wandb.wandb_lib = wandb_sdk.lib # type: ignore
  18. init = wandb_sdk.init
  19. setup = wandb_sdk.setup
  20. attach = _attach = wandb_sdk._attach
  21. teardown = _teardown = wandb_sdk.teardown
  22. finish = wandb_sdk.finish
  23. join = finish
  24. login = wandb_sdk.login
  25. helper = wandb_sdk.helper
  26. sweep = wandb_sdk.sweep
  27. controller = wandb_sdk.controller
  28. require = wandb_sdk.require
  29. Artifact = wandb_sdk.Artifact
  30. AlertLevel = wandb_sdk.AlertLevel
  31. Settings = wandb_sdk.Settings
  32. Config = wandb_sdk.Config
  33. from wandb.apis import InternalApi, PublicApi
  34. from wandb.errors import CommError, UsageError
  35. from wandb.sdk.lib import preinit as _preinit
  36. from wandb.sdk.lib import lazyloader as _lazyloader
  37. from wandb.integration.torch import wandb_torch
  38. from wandb.sdk.data_types._private import _cleanup_media_tmp_dir
  39. _cleanup_media_tmp_dir()
  40. from wandb.data_types import Graph
  41. from wandb.data_types import Image
  42. from wandb.data_types import Plotly
  43. from wandb.data_types import Video
  44. from wandb.data_types import Audio
  45. from wandb.data_types import Table
  46. from wandb.data_types import Html
  47. from wandb.data_types import box3d
  48. from wandb.data_types import Object3D
  49. from wandb.data_types import Molecule
  50. from wandb.data_types import Histogram
  51. from wandb.data_types import Classes
  52. from wandb.data_types import JoinedTable
  53. from wandb.wandb_agent import agent
  54. from wandb.plot import visualize, plot_table
  55. from wandb.integration.sagemaker import sagemaker_auth
  56. from wandb.sdk.internal import profiler
  57. from wandb.sdk.wandb_run import Run
  58. # Artifact import types
  59. from wandb.sdk.artifacts.artifact_ttl import ArtifactTTL
  60. # globals
  61. Api = PublicApi
  62. api = InternalApi()
  63. run: Run | None = None
  64. config = _preinit.PreInitObject("wandb.config", wandb_sdk.wandb_config.Config)
  65. summary = _preinit.PreInitObject("wandb.summary", wandb_sdk.wandb_summary.Summary)
  66. log = _preinit.PreInitCallable("wandb.log", Run.log) # type: ignore
  67. watch = _preinit.PreInitCallable("wandb.watch", Run.watch) # type: ignore
  68. unwatch = _preinit.PreInitCallable("wandb.unwatch", Run.unwatch) # type: ignore
  69. save = _preinit.PreInitCallable("wandb.save", Run.save) # type: ignore
  70. restore = wandb_sdk.wandb_run.restore
  71. use_artifact = _preinit.PreInitCallable(
  72. "wandb.use_artifact", Run.use_artifact # type: ignore
  73. )
  74. log_artifact = _preinit.PreInitCallable(
  75. "wandb.log_artifact", Run.log_artifact # type: ignore
  76. )
  77. log_model = _preinit.PreInitCallable(
  78. "wandb.log_model", Run.log_model # type: ignore
  79. )
  80. use_model = _preinit.PreInitCallable(
  81. "wandb.use_model", Run.use_model # type: ignore
  82. )
  83. link_model = _preinit.PreInitCallable(
  84. "wandb.link_model", Run.link_model # type: ignore
  85. )
  86. define_metric = _preinit.PreInitCallable(
  87. "wandb.define_metric", Run.define_metric # type: ignore
  88. )
  89. mark_preempting = _preinit.PreInitCallable(
  90. "wandb.mark_preempting", Run.mark_preempting # type: ignore
  91. )
  92. alert = _preinit.PreInitCallable("wandb.alert", Run.alert) # type: ignore
  93. pin_config_keys = _preinit.PreInitCallable(
  94. "wandb.pin_config_keys", Run.pin_config_keys # type: ignore
  95. )
  96. # record of patched libraries
  97. patched = {"tensorboard": [], "keras": [], "gym": []} # type: ignore
  98. keras = _lazyloader.LazyLoader("wandb.keras", globals(), "wandb.integration.keras")
  99. sklearn = _lazyloader.LazyLoader("wandb.sklearn", globals(), "wandb.sklearn")
  100. tensorflow = _lazyloader.LazyLoader(
  101. "wandb.tensorflow", globals(), "wandb.integration.tensorflow"
  102. )
  103. xgboost = _lazyloader.LazyLoader(
  104. "wandb.xgboost", globals(), "wandb.integration.xgboost"
  105. )
  106. catboost = _lazyloader.LazyLoader(
  107. "wandb.catboost", globals(), "wandb.integration.catboost"
  108. )
  109. tensorboard = _lazyloader.LazyLoader(
  110. "wandb.tensorboard", globals(), "wandb.integration.tensorboard"
  111. )
  112. gym = _lazyloader.LazyLoader("wandb.gym", globals(), "wandb.integration.gym")
  113. lightgbm = _lazyloader.LazyLoader(
  114. "wandb.lightgbm", globals(), "wandb.integration.lightgbm"
  115. )
  116. jupyter = _lazyloader.LazyLoader("wandb.jupyter", globals(), "wandb.jupyter")
  117. sacred = _lazyloader.LazyLoader("wandb.sacred", globals(), "wandb.integration.sacred")
  118. def ensure_configured():
  119. global api
  120. api = InternalApi()
  121. def set_trace():
  122. import pdb
  123. pdb.set_trace()
  124. if wandb_sdk.lib.ipython.in_notebook():
  125. from IPython import get_ipython # type: ignore[import-not-found]
  126. jupyter._load_ipython_extension(get_ipython())
  127. if "dev" in __version__:
  128. import wandb.env
  129. import os
  130. # Disable error reporting in dev versions.
  131. os.environ[wandb.env.ERROR_REPORTING] = os.environ.get(
  132. wandb.env.ERROR_REPORTING,
  133. "false",
  134. )
  135. __all__ = (
  136. "__version__",
  137. "init",
  138. "finish",
  139. "setup",
  140. "save",
  141. "sweep",
  142. "controller",
  143. "agent",
  144. "config",
  145. "log",
  146. "summary",
  147. "join",
  148. "Api",
  149. "Graph",
  150. "Image",
  151. "Plotly",
  152. "Video",
  153. "Audio",
  154. "Table",
  155. "Html",
  156. "box3d",
  157. "Object3D",
  158. "Molecule",
  159. "Histogram",
  160. "ArtifactTTL",
  161. "log_artifact",
  162. "use_artifact",
  163. "log_model",
  164. "use_model",
  165. "link_model",
  166. "define_metric",
  167. "watch",
  168. "unwatch",
  169. "plot_table",
  170. "Run",
  171. )