comet.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. import os
  2. from pathlib import Path
  3. from typing import Dict, List
  4. import pyarrow.fs
  5. from ray.tune.experiment import Trial
  6. from ray.tune.logger import LoggerCallback
  7. from ray.tune.utils import flatten_dict
  8. def _import_comet():
  9. """Try importing comet_ml.
  10. Used to check if comet_ml is installed and, otherwise, pass an informative
  11. error message.
  12. """
  13. if "COMET_DISABLE_AUTO_LOGGING" not in os.environ:
  14. os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
  15. try:
  16. import comet_ml # noqa: F401
  17. except ImportError:
  18. raise RuntimeError("pip install 'comet-ml' to use CometLoggerCallback")
  19. return comet_ml
  20. class CometLoggerCallback(LoggerCallback):
  21. """CometLoggerCallback for logging Tune results to Comet.
  22. Comet (https://comet.ml/site/) is a tool to manage and optimize the
  23. entire ML lifecycle, from experiment tracking, model optimization
  24. and dataset versioning to model production monitoring.
  25. This Ray Tune ``LoggerCallback`` sends metrics and parameters to
  26. Comet for tracking.
  27. In order to use the CometLoggerCallback you must first install Comet
  28. via ``pip install comet_ml``
  29. Then set the following environment variables
  30. ``export COMET_API_KEY=<Your API Key>``
  31. Alternatively, you can also pass in your API Key as an argument to the
  32. CometLoggerCallback constructor.
  33. ``CometLoggerCallback(api_key=<Your API Key>)``
  34. Args:
  35. online: Whether to make use of an Online or
  36. Offline Experiment. Defaults to True.
  37. tags: Tags to add to the logged Experiment.
  38. Defaults to None.
  39. save_checkpoints: If ``True``, model checkpoints will be saved to
  40. Comet ML as artifacts. Defaults to ``False``.
  41. **experiment_kwargs: Other keyword arguments will be passed to the
  42. constructor for comet_ml.Experiment (or OfflineExperiment if
  43. online=False).
  44. Please consult the Comet ML documentation for more information on the
  45. Experiment and OfflineExperiment classes: https://comet.ml/site/
  46. Example:
  47. .. code-block:: python
  48. from ray.air.integrations.comet import CometLoggerCallback
  49. tune.run(
  50. train,
  51. config=config
  52. callbacks=[CometLoggerCallback(
  53. True,
  54. ['tag1', 'tag2'],
  55. workspace='my_workspace',
  56. project_name='my_project_name'
  57. )]
  58. )
  59. """
  60. # Do not enable these auto log options unless overridden
  61. _exclude_autolog = [
  62. "auto_output_logging",
  63. "log_git_metadata",
  64. "log_git_patch",
  65. "log_env_cpu",
  66. "log_env_gpu",
  67. ]
  68. # Do not log these metrics.
  69. _exclude_results = ["done", "should_checkpoint"]
  70. # These values should be logged as system info instead of metrics.
  71. _system_results = ["node_ip", "hostname", "pid", "date"]
  72. # These values should be logged as "Other" instead of as metrics.
  73. _other_results = ["trial_id", "experiment_id", "experiment_tag"]
  74. _episode_results = ["hist_stats/episode_reward", "hist_stats/episode_lengths"]
  75. def __init__(
  76. self,
  77. online: bool = True,
  78. tags: List[str] = None,
  79. save_checkpoints: bool = False,
  80. **experiment_kwargs,
  81. ):
  82. _import_comet()
  83. self.online = online
  84. self.tags = tags
  85. self.save_checkpoints = save_checkpoints
  86. self.experiment_kwargs = experiment_kwargs
  87. # Disable the specific autologging features that cause throttling.
  88. self._configure_experiment_defaults()
  89. # Mapping from trial to experiment object.
  90. self._trial_experiments = {}
  91. self._to_exclude = self._exclude_results.copy()
  92. self._to_system = self._system_results.copy()
  93. self._to_other = self._other_results.copy()
  94. self._to_episodes = self._episode_results.copy()
  95. def _configure_experiment_defaults(self):
  96. """Disable the specific autologging features that cause throttling."""
  97. for option in self._exclude_autolog:
  98. if not self.experiment_kwargs.get(option):
  99. self.experiment_kwargs[option] = False
  100. def _check_key_name(self, key: str, item: str) -> bool:
  101. """
  102. Check if key argument is equal to item argument or starts with item and
  103. a forward slash. Used for parsing trial result dictionary into ignored
  104. keys, system metrics, episode logs, etc.
  105. """
  106. return key.startswith(item + "/") or key == item
  107. def log_trial_start(self, trial: "Trial"):
  108. """
  109. Initialize an Experiment (or OfflineExperiment if self.online=False)
  110. and start logging to Comet.
  111. Args:
  112. trial: Trial object.
  113. """
  114. _import_comet() # is this necessary?
  115. from comet_ml import Experiment, OfflineExperiment
  116. from comet_ml.config import set_global_experiment
  117. if trial not in self._trial_experiments:
  118. experiment_cls = Experiment if self.online else OfflineExperiment
  119. experiment = experiment_cls(**self.experiment_kwargs)
  120. self._trial_experiments[trial] = experiment
  121. # Set global experiment to None to allow for multiple experiments.
  122. set_global_experiment(None)
  123. else:
  124. experiment = self._trial_experiments[trial]
  125. experiment.set_name(str(trial))
  126. experiment.add_tags(self.tags)
  127. experiment.log_other("Created from", "Ray")
  128. config = trial.config.copy()
  129. config.pop("callbacks", None)
  130. experiment.log_parameters(config)
  131. def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
  132. """
  133. Log the current result of a Trial upon each iteration.
  134. """
  135. if trial not in self._trial_experiments:
  136. self.log_trial_start(trial)
  137. experiment = self._trial_experiments[trial]
  138. step = result["training_iteration"]
  139. config_update = result.pop("config", {}).copy()
  140. config_update.pop("callbacks", None) # Remove callbacks
  141. for k, v in config_update.items():
  142. if isinstance(v, dict):
  143. experiment.log_parameters(flatten_dict({k: v}, "/"), step=step)
  144. else:
  145. experiment.log_parameter(k, v, step=step)
  146. other_logs = {}
  147. metric_logs = {}
  148. system_logs = {}
  149. episode_logs = {}
  150. flat_result = flatten_dict(result, delimiter="/")
  151. for k, v in flat_result.items():
  152. if any(self._check_key_name(k, item) for item in self._to_exclude):
  153. continue
  154. if any(self._check_key_name(k, item) for item in self._to_other):
  155. other_logs[k] = v
  156. elif any(self._check_key_name(k, item) for item in self._to_system):
  157. system_logs[k] = v
  158. elif any(self._check_key_name(k, item) for item in self._to_episodes):
  159. episode_logs[k] = v
  160. else:
  161. metric_logs[k] = v
  162. experiment.log_others(other_logs)
  163. experiment.log_metrics(metric_logs, step=step)
  164. for k, v in system_logs.items():
  165. experiment.log_system_info(k, v)
  166. for k, v in episode_logs.items():
  167. experiment.log_curve(k, x=range(len(v)), y=v, step=step)
  168. def log_trial_save(self, trial: "Trial"):
  169. comet_ml = _import_comet()
  170. if self.save_checkpoints and trial.checkpoint:
  171. experiment = self._trial_experiments[trial]
  172. artifact = comet_ml.Artifact(
  173. name=f"checkpoint_{(str(trial))}", artifact_type="model"
  174. )
  175. checkpoint_root = None
  176. if isinstance(trial.checkpoint.filesystem, pyarrow.fs.LocalFileSystem):
  177. checkpoint_root = trial.checkpoint.path
  178. # Todo: For other filesystems, we may want to use
  179. # artifact.add_remote() instead. However, this requires a full
  180. # URI. We can add this once we have a way to retrieve it.
  181. # Walk through checkpoint directory and add all files to artifact
  182. if checkpoint_root:
  183. for root, dirs, files in os.walk(checkpoint_root):
  184. rel_root = os.path.relpath(root, checkpoint_root)
  185. for file in files:
  186. local_file = Path(checkpoint_root, rel_root, file).as_posix()
  187. logical_path = Path(rel_root, file).as_posix()
  188. # Strip leading `./`
  189. if logical_path.startswith("./"):
  190. logical_path = logical_path[2:]
  191. artifact.add(local_file, logical_path=logical_path)
  192. experiment.log_artifact(artifact)
  193. def log_trial_end(self, trial: "Trial", failed: bool = False):
  194. self._trial_experiments[trial].end()
  195. del self._trial_experiments[trial]
  196. def __del__(self):
  197. for trial, experiment in self._trial_experiments.items():
  198. experiment.end()
  199. self._trial_experiments = {}