mlflow.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. import logging
  2. from types import ModuleType
  3. from typing import Dict, Optional, Union
  4. import ray
  5. from ray.air._internal import usage as air_usage
  6. from ray.air._internal.mlflow import _MLflowLoggerUtil
  7. from ray.air.constants import TRAINING_ITERATION
  8. from ray.tune.experiment import Trial
  9. from ray.tune.logger import LoggerCallback
  10. from ray.tune.result import TIMESTEPS_TOTAL
  11. from ray.tune.trainable.trainable_fn_utils import _in_tune_session
  12. from ray.util.annotations import PublicAPI
  13. try:
  14. import mlflow
  15. except ImportError:
  16. mlflow = None
  17. logger = logging.getLogger(__name__)
  18. class _NoopModule:
  19. def __getattr__(self, item):
  20. return _NoopModule()
  21. def __call__(self, *args, **kwargs):
  22. return None
  23. @PublicAPI(stability="alpha")
  24. def setup_mlflow(
  25. config: Optional[Dict] = None,
  26. tracking_uri: Optional[str] = None,
  27. registry_uri: Optional[str] = None,
  28. experiment_id: Optional[str] = None,
  29. experiment_name: Optional[str] = None,
  30. tracking_token: Optional[str] = None,
  31. artifact_location: Optional[str] = None,
  32. run_name: Optional[str] = None,
  33. create_experiment_if_not_exists: bool = False,
  34. tags: Optional[Dict] = None,
  35. rank_zero_only: bool = True,
  36. ) -> Union[ModuleType, _NoopModule]:
  37. """Set up a MLflow session.
  38. This function can be used to initialize an MLflow session in a
  39. (distributed) training or tuning run. The session will be created on the trainable.
  40. By default, the MLflow experiment ID is the Ray trial ID and the
  41. MLlflow experiment name is the Ray trial name. These settings can be overwritten by
  42. passing the respective keyword arguments.
  43. The ``config`` dict is automatically logged as the run parameters (excluding the
  44. mlflow settings).
  45. In distributed training with Ray Train, only the zero-rank worker will initialize
  46. mlflow. All other workers will return a noop client, so that logging is not
  47. duplicated in a distributed run. This can be disabled by passing
  48. ``rank_zero_only=False``, which will then initialize mlflow in every training
  49. worker. Note: for Ray Tune, there's no concept of worker ranks, so the `rank_zero_only` is ignored.
  50. This function will return the ``mlflow`` module or a noop module for
  51. non-rank zero workers ``if rank_zero_only=True``. By using
  52. ``mlflow = setup_mlflow(config)`` you can ensure that only the rank zero worker
  53. calls the mlflow API.
  54. Args:
  55. config: Configuration dict to be logged to mlflow as parameters.
  56. tracking_uri: The tracking URI for MLflow tracking. If using
  57. Tune in a multi-node setting, make sure to use a remote server for
  58. tracking.
  59. registry_uri: The registry URI for the MLflow model registry.
  60. experiment_id: The id of an already created MLflow experiment.
  61. All logs from all trials in ``tune.Tuner()`` will be reported to this
  62. experiment. If this is not provided or the experiment with this
  63. id does not exist, you must provide an``experiment_name``. This
  64. parameter takes precedence over ``experiment_name``.
  65. experiment_name: The name of an already existing MLflow
  66. experiment. All logs from all trials in ``tune.Tuner()`` will be
  67. reported to this experiment. If this is not provided, you must
  68. provide a valid ``experiment_id``.
  69. tracking_token: A token to use for HTTP authentication when
  70. logging to a remote tracking server. This is useful when you
  71. want to log to a Databricks server, for example. This value will
  72. be used to set the MLFLOW_TRACKING_TOKEN environment variable on
  73. all the remote training processes.
  74. artifact_location: The location to store run artifacts.
  75. If not provided, MLFlow picks an appropriate default.
  76. Ignored if experiment already exists.
  77. run_name: Name of the new MLflow run that will be created.
  78. If not set, will default to the ``experiment_name``.
  79. create_experiment_if_not_exists: Whether to create an
  80. experiment with the provided name if it does not already
  81. exist. Defaults to False.
  82. tags: Tags to set for the new run.
  83. rank_zero_only: If True, will return an initialized session only for the
  84. rank 0 worker in distributed training. If False, will initialize a
  85. session for all workers. Defaults to True.
  86. Example:
  87. Per default, you can just call ``setup_mlflow`` and continue to use
  88. MLflow like you would normally do:
  89. .. code-block:: python
  90. from ray.air.integrations.mlflow import setup_mlflow
  91. def training_loop(config):
  92. mlflow = setup_mlflow(config)
  93. # ...
  94. mlflow.log_metric(key="loss", val=0.123, step=0)
  95. In distributed data parallel training, you can utilize the return value of
  96. ``setup_mlflow``. This will make sure it is only invoked on the first worker
  97. in distributed training runs.
  98. .. code-block:: python
  99. from ray.air.integrations.mlflow import setup_mlflow
  100. def training_loop(config):
  101. mlflow = setup_mlflow(config)
  102. # ...
  103. mlflow.log_metric(key="loss", val=0.123, step=0)
  104. You can also use MlFlow's autologging feature if using a training
  105. framework like Pytorch Lightning, XGBoost, etc. More information can be
  106. found here
  107. (https://mlflow.org/docs/latest/tracking.html#automatic-logging).
  108. .. code-block:: python
  109. from ray.air.integrations.mlflow import setup_mlflow
  110. def train_fn(config):
  111. mlflow = setup_mlflow(config)
  112. mlflow.autolog()
  113. xgboost_results = xgb.train(config, ...)
  114. """
  115. if not mlflow:
  116. raise RuntimeError(
  117. "mlflow was not found - please install with `pip install mlflow`"
  118. )
  119. default_trial_id = None
  120. default_trial_name = None
  121. try:
  122. if _in_tune_session():
  123. context: ray.tune.TuneContext = ray.tune.get_context()
  124. default_trial_id = context.get_trial_id()
  125. default_trial_name = context.get_trial_name()
  126. else:
  127. context: ray.train.TrainContext = ray.train.get_context()
  128. if rank_zero_only and context.get_world_rank() != 0:
  129. return _NoopModule()
  130. except RuntimeError:
  131. default_trial_id = None
  132. default_trial_name = None
  133. _config = config.copy() if config else {}
  134. experiment_id = experiment_id or default_trial_id
  135. experiment_name = experiment_name or default_trial_name
  136. # Setup mlflow
  137. mlflow_util = _MLflowLoggerUtil()
  138. mlflow_util.setup_mlflow(
  139. tracking_uri=tracking_uri,
  140. registry_uri=registry_uri,
  141. experiment_id=experiment_id,
  142. experiment_name=experiment_name,
  143. tracking_token=tracking_token,
  144. artifact_location=artifact_location,
  145. create_experiment_if_not_exists=create_experiment_if_not_exists,
  146. )
  147. mlflow_util.start_run(
  148. run_name=run_name or experiment_name,
  149. tags=tags,
  150. set_active=True,
  151. )
  152. mlflow_util.log_params(_config)
  153. # Record `setup_mlflow` usage when everything has setup successfully.
  154. air_usage.tag_setup_mlflow()
  155. return mlflow_util._mlflow
  156. class MLflowLoggerCallback(LoggerCallback):
  157. """MLflow Logger to automatically log Tune results and config to MLflow.
  158. MLflow (https://mlflow.org) Tracking is an open source library for
  159. recording and querying experiments. This Ray Tune ``LoggerCallback``
  160. sends information (config parameters, training results & metrics,
  161. and artifacts) to MLflow for automatic experiment tracking.
  162. Keep in mind that the callback will open an MLflow session on the driver and
  163. not on the trainable. Therefore, it is not possible to call MLflow functions
  164. like ``mlflow.log_figure()`` inside the trainable as there is no MLflow session
  165. on the trainable. For more fine grained control, use
  166. :func:`ray.air.integrations.mlflow.setup_mlflow`.
  167. Args:
  168. tracking_uri: The tracking URI for where to manage experiments
  169. and runs. This can either be a local file path or a remote server.
  170. This arg gets passed directly to mlflow
  171. initialization. When using Tune in a multi-node setting, make sure
  172. to set this to a remote server and not a local file path.
  173. registry_uri: The registry URI that gets passed directly to
  174. mlflow initialization.
  175. experiment_name: The experiment name to use for this Tune run.
  176. If the experiment with the name already exists with MLflow,
  177. it will be reused. If not, a new experiment will be created with
  178. that name.
  179. tags: An optional dictionary of string keys and values to set
  180. as tags on the run
  181. tracking_token: Tracking token used to authenticate with MLflow.
  182. save_artifact: If set to True, automatically save the entire
  183. contents of the Tune local_dir as an artifact to the
  184. corresponding run in MlFlow.
  185. log_params_on_trial_end: If set to True, log parameters to MLflow
  186. at the end of the trial instead of at the beginning
  187. Example:
  188. .. code-block:: python
  189. from ray.air.integrations.mlflow import MLflowLoggerCallback
  190. tags = { "user_name" : "John",
  191. "git_commit_hash" : "abc123"}
  192. tune.run(
  193. train_fn,
  194. config={
  195. # define search space here
  196. "parameter_1": tune.choice([1, 2, 3]),
  197. "parameter_2": tune.choice([4, 5, 6]),
  198. },
  199. callbacks=[MLflowLoggerCallback(
  200. experiment_name="experiment1",
  201. tags=tags,
  202. save_artifact=True,
  203. log_params_on_trial_end=True)])
  204. """
  205. def __init__(
  206. self,
  207. tracking_uri: Optional[str] = None,
  208. *,
  209. registry_uri: Optional[str] = None,
  210. experiment_name: Optional[str] = None,
  211. tags: Optional[Dict] = None,
  212. tracking_token: Optional[str] = None,
  213. save_artifact: bool = False,
  214. log_params_on_trial_end: bool = False,
  215. ):
  216. self.tracking_uri = tracking_uri
  217. self.registry_uri = registry_uri
  218. self.experiment_name = experiment_name
  219. self.tags = tags
  220. self.tracking_token = tracking_token
  221. self.should_save_artifact = save_artifact
  222. self.log_params_on_trial_end = log_params_on_trial_end
  223. self.mlflow_util = _MLflowLoggerUtil()
  224. if ray.util.client.ray.is_connected():
  225. logger.warning(
  226. "When using MLflowLoggerCallback with Ray Client, "
  227. "it is recommended to use a remote tracking "
  228. "server. If you are using a MLflow tracking server "
  229. "backed by the local filesystem, then it must be "
  230. "setup on the server side and not on the client "
  231. "side."
  232. )
  233. def setup(self, *args, **kwargs):
  234. # Setup the mlflow logging util.
  235. self.mlflow_util.setup_mlflow(
  236. tracking_uri=self.tracking_uri,
  237. registry_uri=self.registry_uri,
  238. experiment_name=self.experiment_name,
  239. tracking_token=self.tracking_token,
  240. )
  241. if self.tags is None:
  242. # Create empty dictionary for tags if not given explicitly
  243. self.tags = {}
  244. self._trial_runs = {}
  245. def log_trial_start(self, trial: "Trial"):
  246. # Create run if not already exists.
  247. if trial not in self._trial_runs:
  248. # Set trial name in tags
  249. tags = self.tags.copy()
  250. tags["trial_name"] = str(trial)
  251. run = self.mlflow_util.start_run(tags=tags, run_name=str(trial))
  252. self._trial_runs[trial] = run.info.run_id
  253. run_id = self._trial_runs[trial]
  254. # Log the config parameters.
  255. config = trial.config
  256. if not self.log_params_on_trial_end:
  257. self.mlflow_util.log_params(run_id=run_id, params_to_log=config)
  258. def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
  259. step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
  260. run_id = self._trial_runs[trial]
  261. self.mlflow_util.log_metrics(run_id=run_id, metrics_to_log=result, step=step)
  262. def log_trial_end(self, trial: "Trial", failed: bool = False):
  263. run_id = self._trial_runs[trial]
  264. # Log the artifact if set_artifact is set to True.
  265. if self.should_save_artifact:
  266. self.mlflow_util.save_artifacts(run_id=run_id, dir=trial.local_path)
  267. # Stop the run once trial finishes.
  268. status = "FINISHED" if not failed else "FAILED"
  269. # Log the config parameters.
  270. config = trial.config
  271. if self.log_params_on_trial_end:
  272. self.mlflow_util.log_params(run_id=run_id, params_to_log=config)
  273. self.mlflow_util.end_run(run_id=run_id, status=status)