pbt_convnet_example.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. #!/usr/bin/env python
  2. # ruff: noqa
  3. # fmt: off
  4. # __tutorial_imports_begin__
  5. import argparse
  6. import os
  7. import numpy as np
  8. import torch
  9. import torch.optim as optim
  10. from torchvision import datasets
  11. import ray
  12. from ray import tune
  13. from ray.tune.examples.mnist_pytorch import (
  14. ConvNet,
  15. get_data_loaders,
  16. test_func,
  17. train_func,
  18. )
  19. from ray.tune.schedulers import PopulationBasedTraining
  20. from ray.tune.utils import validate_save_restore
  21. # __tutorial_imports_end__
  22. # __trainable_begin__
  23. class PytorchTrainable(tune.Trainable):
  24. """Train a Pytorch ConvNet with Trainable and PopulationBasedTraining
  25. scheduler. The example reuse some of the functions in mnist_pytorch,
  26. and is a good demo for how to add the tuning function without
  27. changing the original training code.
  28. """
  29. def setup(self, config):
  30. self.train_loader, self.test_loader = get_data_loaders()
  31. self.model = ConvNet()
  32. self.optimizer = optim.SGD(
  33. self.model.parameters(),
  34. lr=config.get("lr", 0.01),
  35. momentum=config.get("momentum", 0.9))
  36. def step(self):
  37. train_func(self.model, self.optimizer, self.train_loader)
  38. acc = test_func(self.model, self.test_loader)
  39. return {"mean_accuracy": acc}
  40. def save_checkpoint(self, checkpoint_dir):
  41. checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
  42. torch.save(self.model.state_dict(), checkpoint_path)
  43. def load_checkpoint(self, checkpoint_dir):
  44. checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
  45. self.model.load_state_dict(torch.load(checkpoint_path))
  46. def reset_config(self, new_config):
  47. for param_group in self.optimizer.param_groups:
  48. if "lr" in new_config:
  49. param_group["lr"] = new_config["lr"]
  50. if "momentum" in new_config:
  51. param_group["momentum"] = new_config["momentum"]
  52. self.config = new_config
  53. return True
  54. # __trainable_end__
  55. if __name__ == "__main__":
  56. parser = argparse.ArgumentParser()
  57. parser.add_argument(
  58. "--smoke-test", action="store_true", help="Finish quickly for testing")
  59. args, _ = parser.parse_known_args()
  60. ray.init(num_cpus=2)
  61. datasets.MNIST("~/data", train=True, download=True)
  62. # check if PytorchTrainble will save/restore correctly before execution
  63. validate_save_restore(PytorchTrainable)
  64. # __pbt_begin__
  65. scheduler = PopulationBasedTraining(
  66. time_attr="training_iteration",
  67. perturbation_interval=5,
  68. hyperparam_mutations={
  69. # distribution for resampling
  70. "lr": lambda: np.random.uniform(0.0001, 1),
  71. # allow perturbations within this set of categorical values
  72. "momentum": [0.8, 0.9, 0.99],
  73. })
  74. # __pbt_end__
  75. # __tune_begin__
  76. class CustomStopper(tune.Stopper):
  77. def __init__(self):
  78. self.should_stop = False
  79. def __call__(self, trial_id, result):
  80. max_iter = 5 if args.smoke_test else 100
  81. if not self.should_stop and result["mean_accuracy"] > 0.96:
  82. self.should_stop = True
  83. return self.should_stop or result["training_iteration"] >= max_iter
  84. def stop_all(self):
  85. return self.should_stop
  86. stopper = CustomStopper()
  87. tuner = tune.Tuner(
  88. PytorchTrainable,
  89. run_config=tune.RunConfig(
  90. name="pbt_test",
  91. stop=stopper,
  92. verbose=1,
  93. checkpoint_config=tune.CheckpointConfig(
  94. checkpoint_score_attribute="mean_accuracy",
  95. checkpoint_frequency=5,
  96. num_to_keep=4,
  97. ),
  98. ),
  99. tune_config=tune.TuneConfig(
  100. scheduler=scheduler,
  101. metric="mean_accuracy",
  102. mode="max",
  103. num_samples=4,
  104. reuse_actors=True,
  105. ),
  106. param_space={
  107. "lr": tune.uniform(0.001, 1),
  108. "momentum": tune.uniform(0.001, 1),
  109. },
  110. )
  111. results = tuner.fit()
  112. # __tune_end__
  113. best_result = results.get_best_result()
  114. best_checkpoint = best_result.checkpoint