mlflow.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. import logging
  2. import os
  3. from copy import deepcopy
  4. from typing import TYPE_CHECKING, Dict, Optional
  5. from packaging import version
  6. from ray._private.dict import flatten_dict
  7. if TYPE_CHECKING:
  8. from mlflow.entities import Run
  9. from mlflow.tracking import MlflowClient
  10. logger = logging.getLogger(__name__)
  11. class _MLflowLoggerUtil:
  12. """Util class for setting up and logging to MLflow.
  13. Use this util for any library that needs MLflow logging/tracking logic
  14. such as Ray Tune or Ray Train.
  15. """
  16. def __init__(self):
  17. import mlflow
  18. self._mlflow = mlflow
  19. self.experiment_id = None
  20. def __deepcopy__(self, memo=None):
  21. # mlflow is a module, and thus cannot be copied
  22. _mlflow = self._mlflow
  23. self.__dict__.pop("_mlflow")
  24. dict_copy = deepcopy(self.__dict__, memo)
  25. copied_object = _MLflowLoggerUtil()
  26. copied_object.__dict__.update(dict_copy)
  27. self._mlflow = _mlflow
  28. copied_object._mlflow = _mlflow
  29. return copied_object
  30. def setup_mlflow(
  31. self,
  32. tracking_uri: Optional[str] = None,
  33. registry_uri: Optional[str] = None,
  34. experiment_id: Optional[str] = None,
  35. experiment_name: Optional[str] = None,
  36. tracking_token: Optional[str] = None,
  37. artifact_location: Optional[str] = None,
  38. create_experiment_if_not_exists: bool = True,
  39. ):
  40. """
  41. Sets up MLflow.
  42. Sets the Mlflow tracking uri & token, and registry URI. Also sets
  43. the MLflow experiment that the logger should use, and possibly
  44. creates new experiment if it does not exist.
  45. Args:
  46. tracking_uri: The tracking URI for the MLflow tracking
  47. server.
  48. registry_uri: The registry URI for the MLflow model registry.
  49. experiment_id: The id of an already existing MLflow
  50. experiment to use for logging. If None is passed in
  51. here and the MFLOW_EXPERIMENT_ID is not set, or the
  52. experiment with this id does not exist,
  53. ``experiment_name`` will be used instead. This argument takes
  54. precedence over ``experiment_name`` if both are passed in.
  55. experiment_name: The experiment name to use for logging.
  56. If None is passed in here, the MLFLOW_EXPERIMENT_NAME environment
  57. variable is used to determine the experiment name.
  58. If the experiment with the name already exists with MLflow,
  59. it will be reused. If not, a new experiment will be created
  60. with the provided name if
  61. ``create_experiment_if_not_exists`` is set to True.
  62. artifact_location: The location to store run artifacts.
  63. If not provided, MLFlow picks an appropriate default.
  64. Ignored if experiment already exists.
  65. tracking_token: Tracking token used to authenticate with MLflow.
  66. create_experiment_if_not_exists: Whether to create an
  67. experiment with the provided name if it does not already
  68. exist. Defaults to True.
  69. Returns:
  70. Whether setup is successful.
  71. """
  72. if tracking_token:
  73. os.environ["MLFLOW_TRACKING_TOKEN"] = tracking_token
  74. self._mlflow.set_tracking_uri(tracking_uri)
  75. self._mlflow.set_registry_uri(registry_uri)
  76. # First check experiment_id.
  77. experiment_id = (
  78. experiment_id
  79. if experiment_id is not None
  80. else os.environ.get("MLFLOW_EXPERIMENT_ID")
  81. )
  82. if experiment_id is not None:
  83. from mlflow.exceptions import MlflowException
  84. try:
  85. self._mlflow.get_experiment(experiment_id=experiment_id)
  86. logger.debug(
  87. f"Experiment with provided id {experiment_id} "
  88. "exists. Setting that as the experiment."
  89. )
  90. self.experiment_id = experiment_id
  91. return
  92. except MlflowException:
  93. pass
  94. # Then check experiment_name.
  95. experiment_name = (
  96. experiment_name
  97. if experiment_name is not None
  98. else os.environ.get("MLFLOW_EXPERIMENT_NAME")
  99. )
  100. if experiment_name is not None and self._mlflow.get_experiment_by_name(
  101. name=experiment_name
  102. ):
  103. logger.debug(
  104. f"Experiment with provided name {experiment_name} "
  105. "exists. Setting that as the experiment."
  106. )
  107. self.experiment_id = self._mlflow.get_experiment_by_name(
  108. experiment_name
  109. ).experiment_id
  110. return
  111. # An experiment with the provided id or name does not exist.
  112. # Create a new experiment if applicable.
  113. if experiment_name and create_experiment_if_not_exists:
  114. logger.debug(
  115. "Existing experiment not found. Creating new "
  116. f"experiment with name: {experiment_name}"
  117. )
  118. self.experiment_id = self._mlflow.create_experiment(
  119. name=experiment_name, artifact_location=artifact_location
  120. )
  121. return
  122. if create_experiment_if_not_exists:
  123. raise ValueError(
  124. f"Experiment with the provided experiment_id: "
  125. f"{experiment_id} does not exist and no "
  126. f"experiment_name provided. At least one of "
  127. f"these has to be provided."
  128. )
  129. else:
  130. raise ValueError(
  131. f"Experiment with the provided experiment_id: "
  132. f"{experiment_id} or experiment_name: "
  133. f"{experiment_name} does not exist. Please "
  134. f"create an MLflow experiment and provide "
  135. f"either its id or name."
  136. )
  137. def _parse_dict(self, dict_to_log: Dict) -> Dict:
  138. """Parses provided dict to convert all values to float.
  139. MLflow can only log metrics that are floats. This does not apply to
  140. logging parameters or artifacts.
  141. Args:
  142. dict_to_log: The dictionary containing the metrics to log.
  143. Returns:
  144. A dictionary containing the metrics to log with all values being
  145. converted to floats, or skipped if not able to be converted.
  146. """
  147. new_dict = {}
  148. for key, value in dict_to_log.items():
  149. try:
  150. value = float(value)
  151. new_dict[key] = value
  152. except (ValueError, TypeError):
  153. logger.debug(
  154. "Cannot log key {} with value {} since the "
  155. "value cannot be converted to float.".format(key, value)
  156. )
  157. continue
  158. return new_dict
  159. def start_run(
  160. self,
  161. run_name: Optional[str] = None,
  162. tags: Optional[Dict] = None,
  163. set_active: bool = False,
  164. ) -> "Run":
  165. """Starts a new run and possibly sets it as the active run.
  166. Args:
  167. tags: Tags to set for the new run.
  168. set_active: Whether to set the new run as the active run.
  169. If an active run already exists, then that run is returned.
  170. Returns:
  171. The newly created MLflow run.
  172. """
  173. import mlflow
  174. from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
  175. if tags is None:
  176. tags = {}
  177. if set_active:
  178. return self._start_active_run(run_name=run_name, tags=tags)
  179. client = self._get_client()
  180. # If `mlflow==1.30.0` and we don't use `run_name`, then MLflow might error. For
  181. # more information, see #29749.
  182. if version.parse(mlflow.__version__) >= version.parse("1.30.0"):
  183. run = client.create_run(
  184. run_name=run_name, experiment_id=self.experiment_id, tags=tags
  185. )
  186. else:
  187. tags[MLFLOW_RUN_NAME] = run_name
  188. run = client.create_run(experiment_id=self.experiment_id, tags=tags)
  189. return run
  190. def _start_active_run(
  191. self, run_name: Optional[str] = None, tags: Optional[Dict] = None
  192. ) -> "Run":
  193. """Starts a run and sets it as the active run if one does not exist.
  194. If an active run already exists, then returns it.
  195. """
  196. active_run = self._mlflow.active_run()
  197. if active_run:
  198. return active_run
  199. return self._mlflow.start_run(
  200. run_name=run_name, experiment_id=self.experiment_id, tags=tags
  201. )
  202. def _run_exists(self, run_id: str) -> bool:
  203. """Check if run with the provided id exists."""
  204. from mlflow.exceptions import MlflowException
  205. try:
  206. self._mlflow.get_run(run_id=run_id)
  207. return True
  208. except MlflowException:
  209. return False
  210. def _get_client(self) -> "MlflowClient":
  211. """Returns an ml.tracking.MlflowClient instance to use for logging."""
  212. tracking_uri = self._mlflow.get_tracking_uri()
  213. registry_uri = self._mlflow.get_registry_uri()
  214. from mlflow.tracking import MlflowClient
  215. return MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)
  216. def log_params(self, params_to_log: Dict, run_id: Optional[str] = None):
  217. """Logs the provided parameters to the run specified by run_id.
  218. If no ``run_id`` is passed in, then logs to the current active run.
  219. If there is not active run, then creates a new run and sets it as
  220. the active run.
  221. Args:
  222. params_to_log: Dictionary of parameters to log.
  223. run_id (Optional[str]): The ID of the run to log to.
  224. """
  225. params_to_log = flatten_dict(params_to_log)
  226. if run_id and self._run_exists(run_id):
  227. client = self._get_client()
  228. for key, value in params_to_log.items():
  229. client.log_param(run_id=run_id, key=key, value=value)
  230. else:
  231. for key, value in params_to_log.items():
  232. self._mlflow.log_param(key=key, value=value)
  233. def log_metrics(self, step, metrics_to_log: Dict, run_id: Optional[str] = None):
  234. """Logs the provided metrics to the run specified by run_id.
  235. If no ``run_id`` is passed in, then logs to the current active run.
  236. If there is not active run, then creates a new run and sets it as
  237. the active run.
  238. Args:
  239. metrics_to_log: Dictionary of metrics to log.
  240. run_id (Optional[str]): The ID of the run to log to.
  241. """
  242. metrics_to_log = flatten_dict(metrics_to_log)
  243. metrics_to_log = self._parse_dict(metrics_to_log)
  244. if run_id and self._run_exists(run_id):
  245. client = self._get_client()
  246. for key, value in metrics_to_log.items():
  247. client.log_metric(run_id=run_id, key=key, value=value, step=step)
  248. else:
  249. for key, value in metrics_to_log.items():
  250. self._mlflow.log_metric(key=key, value=value, step=step)
  251. def save_artifacts(self, dir: str, run_id: Optional[str] = None):
  252. """Saves directory as artifact to the run specified by run_id.
  253. If no ``run_id`` is passed in, then saves to the current active run.
  254. If there is not active run, then creates a new run and sets it as
  255. the active run.
  256. Args:
  257. dir: Path to directory containing the files to save.
  258. run_id (Optional[str]): The ID of the run to log to.
  259. """
  260. if run_id and self._run_exists(run_id):
  261. client = self._get_client()
  262. client.log_artifacts(run_id=run_id, local_dir=dir)
  263. else:
  264. self._mlflow.log_artifacts(local_dir=dir)
  265. def end_run(self, status: Optional[str] = None, run_id=None):
  266. """Terminates the run specified by run_id.
  267. If no ``run_id`` is passed in, then terminates the
  268. active run if one exists.
  269. Args:
  270. status (Optional[str]): The status to set when terminating the run.
  271. run_id (Optional[str]): The ID of the run to terminate.
  272. """
  273. if (
  274. run_id
  275. and self._run_exists(run_id)
  276. and not (
  277. self._mlflow.active_run()
  278. and self._mlflow.active_run().info.run_id == run_id
  279. )
  280. ):
  281. client = self._get_client()
  282. client.set_terminated(run_id=run_id, status=status)
  283. else:
  284. self._mlflow.end_run(status=status)