mlflow_example.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. #!/usr/bin/env python
  2. """Examples using MLflowLoggerCallback and setup_mlflow.
  3. """
  4. import os
  5. import tempfile
  6. import time
  7. import mlflow
  8. from ray import tune
  9. from ray.air.integrations.mlflow import MLflowLoggerCallback, setup_mlflow
  10. def evaluation_fn(step, width, height):
  11. return (0.1 + width * step / 100) ** (-1) + height * 0.1
  12. def train_function(config):
  13. # Hyperparameters
  14. width, height = config["width"], config["height"]
  15. for step in range(config.get("steps", 100)):
  16. # Iterative training function - can be any arbitrary training procedure
  17. intermediate_score = evaluation_fn(step, width, height)
  18. # Feed the score back to Tune.
  19. tune.report({"iterations": step, "mean_loss": intermediate_score})
  20. time.sleep(0.1)
  21. def tune_with_callback(mlflow_tracking_uri, finish_fast=False):
  22. tuner = tune.Tuner(
  23. train_function,
  24. run_config=tune.RunConfig(
  25. name="mlflow",
  26. callbacks=[
  27. MLflowLoggerCallback(
  28. tracking_uri=mlflow_tracking_uri,
  29. experiment_name="example",
  30. save_artifact=True,
  31. )
  32. ],
  33. ),
  34. tune_config=tune.TuneConfig(
  35. num_samples=5,
  36. ),
  37. param_space={
  38. "width": tune.randint(10, 100),
  39. "height": tune.randint(0, 100),
  40. "steps": 5 if finish_fast else 100,
  41. },
  42. )
  43. tuner.fit()
  44. def train_function_mlflow(config):
  45. setup_mlflow(config)
  46. # Hyperparameters
  47. width, height = config["width"], config["height"]
  48. for step in range(config.get("steps", 100)):
  49. # Iterative training function - can be any arbitrary training procedure
  50. intermediate_score = evaluation_fn(step, width, height)
  51. # Log the metrics to mlflow
  52. mlflow.log_metrics(dict(mean_loss=intermediate_score), step=step)
  53. # Feed the score back to Tune.
  54. tune.report({"iterations": step, "mean_loss": intermediate_score})
  55. time.sleep(0.1)
  56. def tune_with_setup(mlflow_tracking_uri, finish_fast=False):
  57. # Set the experiment, or create a new one if does not exist yet.
  58. mlflow.set_tracking_uri(mlflow_tracking_uri)
  59. mlflow.set_experiment(experiment_name="mixin_example")
  60. tuner = tune.Tuner(
  61. train_function_mlflow,
  62. run_config=tune.RunConfig(
  63. name="mlflow",
  64. ),
  65. tune_config=tune.TuneConfig(
  66. num_samples=5,
  67. ),
  68. param_space={
  69. "width": tune.randint(10, 100),
  70. "height": tune.randint(0, 100),
  71. "steps": 5 if finish_fast else 100,
  72. "mlflow": {
  73. "experiment_name": "mixin_example",
  74. "tracking_uri": mlflow.get_tracking_uri(),
  75. },
  76. },
  77. )
  78. tuner.fit()
  79. if __name__ == "__main__":
  80. import argparse
  81. parser = argparse.ArgumentParser()
  82. parser.add_argument(
  83. "--smoke-test", action="store_true", help="Finish quickly for testing"
  84. )
  85. parser.add_argument(
  86. "--tracking-uri",
  87. type=str,
  88. help="The tracking URI for the MLflow tracking server.",
  89. )
  90. args, _ = parser.parse_known_args()
  91. if args.smoke_test:
  92. mlflow_tracking_uri = os.path.join(tempfile.gettempdir(), "mlruns")
  93. else:
  94. mlflow_tracking_uri = args.tracking_uri
  95. tune_with_callback(mlflow_tracking_uri, finish_fast=args.smoke_test)
  96. if not args.smoke_test:
  97. df = mlflow.search_runs(
  98. [mlflow.get_experiment_by_name("example").experiment_id]
  99. )
  100. print(df)
  101. tune_with_setup(mlflow_tracking_uri, finish_fast=args.smoke_test)
  102. if not args.smoke_test:
  103. df = mlflow.search_runs(
  104. [mlflow.get_experiment_by_name("mixin_example").experiment_id]
  105. )
  106. print(df)