__init__.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
import os
import warnings

import numpy
from sacred.dependencies import get_digest
from sacred.observers import RunObserver
import wandb
  6. class WandbObserver(RunObserver):
  7. """Log sacred experiment data to W&B.
  8. Args:
  9. Accepts all the arguments accepted by wandb.init().
  10. name — A display name for this run, which shows up in the UI and is editable, doesn't have to be unique
  11. notes — A multiline string description associated with the run
  12. config — a dictionary-like object to set as initial config
  13. project — the name of the project to which this run will belong
  14. tags — a list of strings to associate with this run as tags
  15. dir — the path to a directory where artifacts will be written (default: ./wandb)
  16. entity — the team posting this run (default: your username or your default team)
  17. job_type — the type of job you are logging, e.g. eval, worker, ps (default: training)
  18. save_code — save the main python or notebook file to wandb to enable diffing (default: editable from your settings page)
  19. group — a string by which to group other runs; see Grouping
  20. reinit — Shorthand for the reinit setting that defines what to do when `wandb.init()` is called while a run is active. See the setting's documentation.
  21. id — A unique ID for this run primarily used for Resuming. It must be globally unique, and if you delete a run you can't reuse the ID. Use the name field for a descriptive, useful name for the run. The ID cannot contain special characters.
  22. resume — if set to True, the run auto resumes; can also be a unique string for manual resuming; see Resuming (default: False)
  23. anonymous — can be "allow", "never", or "must". This enables or explicitly disables anonymous logging. (default: never)
  24. force — whether to force a user to be logged into wandb when running a script (default: False)
  25. magic — (bool, dict, or str, optional): magic configuration as bool, dict, json string, yaml filename. If set to True will attempt to auto-instrument your script. (default: None)
  26. sync_tensorboard — A boolean indicating whether or not copy all TensorBoard logs wandb; see Tensorboard (default: False)
  27. monitor_gym — A boolean indicating whether or not to log videos generated by OpenAI Gym; see Ray Tune (default: False)
  28. allow_val_change — whether to allow wandb.config values to change, by default we throw an exception if config values are overwritten. (default: False)
  29. Examples:
  30. Create sacred experiment::
  31. from wandb.sacred import WandbObserver
  32. ex.observers.append(WandbObserver(project='sacred_test',
  33. name='test1'))
  34. @ex.config
  35. def cfg():
  36. C = 1.0
  37. gamma = 0.7
  38. @ex.automain
  39. def run(C, gamma, _run):
  40. iris = datasets.load_iris()
  41. per = permutation(iris.target.size)
  42. iris.data = iris.data[per]
  43. iris.target = iris.target[per]
  44. clf = svm.SVC(C, 'rbf', gamma=gamma)
  45. clf.fit(iris.data[:90],
  46. iris.target[:90])
  47. return clf.score(iris.data[90:],
  48. iris.target[90:])
  49. """
  50. def __init__(self, **kwargs):
  51. self.run = wandb.init(**kwargs)
  52. self.resources = {}
  53. def started_event(
  54. self, ex_info, command, host_info, start_time, config, meta_info, _id
  55. ):
  56. # TODO: add the source code file
  57. # TODO: add dependencies and metadata.
  58. self.__update_config(config)
  59. def completed_event(self, stop_time, result):
  60. if result:
  61. if not isinstance(result, tuple):
  62. result = (
  63. result,
  64. ) # transform single result to tuple so that both single & multiple results use same code
  65. for i, r in enumerate(result):
  66. if isinstance(r, (float, int)):
  67. wandb.log({f"result_{i}": float(r)})
  68. elif isinstance(r, dict):
  69. wandb.log(r)
  70. elif isinstance(r, object):
  71. artifact = wandb.Artifact(f"result_{i}.pkl", type="result")
  72. artifact.add_file(r)
  73. self.run.log_artifact(artifact)
  74. elif isinstance(r, numpy.ndarray):
  75. wandb.log({f"result_{i}": wandb.Image(r)})
  76. else:
  77. warnings.warn(
  78. f"logging results does not support type '{type(r)}' results. Ignoring this result",
  79. stacklevel=2,
  80. )
  81. def artifact_event(self, name, filename, metadata=None, content_type=None):
  82. if content_type is None:
  83. content_type = "file"
  84. artifact = wandb.Artifact(name, type=content_type)
  85. artifact.add_file(filename)
  86. self.run.log_artifact(artifact)
  87. def resource_event(self, filename):
  88. """TODO: Maintain resources list."""
  89. if filename not in self.resources:
  90. md5 = get_digest(filename)
  91. self.resources[filename] = md5
  92. def log_metrics(self, metrics_by_name, info):
  93. for metric_name, metric_ptr in metrics_by_name.items():
  94. for _step, value in zip(metric_ptr["steps"], metric_ptr["values"]):
  95. if isinstance(value, numpy.ndarray):
  96. wandb.log({metric_name: wandb.Image(value)})
  97. else:
  98. wandb.log({metric_name: value})
  99. def __update_config(self, config):
  100. for k, v in config.items():
  101. self.run.config[k] = v
  102. self.run.config["resources"] = []