mnist_pytorch.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Original Code here:
  2. # https://github.com/pytorch/examples/blob/master/mnist/main.py
  3. import argparse
  4. import os
  5. import tempfile
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torch.optim as optim
  10. from filelock import FileLock
  11. from torchvision import datasets, transforms
  12. import ray
  13. from ray import tune
  14. from ray.tune import Checkpoint
  15. from ray.tune.schedulers import AsyncHyperBandScheduler
# Change these values if you want the training to run quicker or slower.
# train_func stops its pass over the training loader once
# batch_idx * len(batch) exceeds EPOCH_SIZE.
EPOCH_SIZE = 512
# test_func stops evaluating once batch_idx * len(batch) exceeds TEST_SIZE.
TEST_SIZE = 256
  19. class ConvNet(nn.Module):
  20. def __init__(self):
  21. super(ConvNet, self).__init__()
  22. self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
  23. self.fc = nn.Linear(192, 10)
  24. def forward(self, x):
  25. x = F.relu(F.max_pool2d(self.conv1(x), 3))
  26. x = x.view(-1, 192)
  27. x = self.fc(x)
  28. return F.log_softmax(x, dim=1)
  29. def train_func(model, optimizer, train_loader, device=None):
  30. device = device or torch.device("cpu")
  31. model.train()
  32. for batch_idx, (data, target) in enumerate(train_loader):
  33. if batch_idx * len(data) > EPOCH_SIZE:
  34. return
  35. data, target = data.to(device), target.to(device)
  36. optimizer.zero_grad()
  37. output = model(data)
  38. loss = F.nll_loss(output, target)
  39. loss.backward()
  40. optimizer.step()
  41. def test_func(model, data_loader, device=None):
  42. device = device or torch.device("cpu")
  43. model.eval()
  44. correct = 0
  45. total = 0
  46. with torch.no_grad():
  47. for batch_idx, (data, target) in enumerate(data_loader):
  48. if batch_idx * len(data) > TEST_SIZE:
  49. break
  50. data, target = data.to(device), target.to(device)
  51. outputs = model(data)
  52. _, predicted = torch.max(outputs.data, 1)
  53. total += target.size(0)
  54. correct += (predicted == target).sum().item()
  55. return correct / total
  56. def get_data_loaders(batch_size=64):
  57. mnist_transforms = transforms.Compose(
  58. [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
  59. )
  60. # We add FileLock here because multiple workers will want to
  61. # download data, and this may cause overwrites since
  62. # DataLoader is not threadsafe.
  63. with FileLock(os.path.expanduser("~/data.lock")):
  64. train_loader = torch.utils.data.DataLoader(
  65. datasets.MNIST(
  66. "~/data", train=True, download=True, transform=mnist_transforms
  67. ),
  68. batch_size=batch_size,
  69. shuffle=True,
  70. )
  71. test_loader = torch.utils.data.DataLoader(
  72. datasets.MNIST(
  73. "~/data", train=False, download=True, transform=mnist_transforms
  74. ),
  75. batch_size=batch_size,
  76. shuffle=True,
  77. )
  78. return train_loader, test_loader
  79. def train_mnist(config):
  80. should_checkpoint = config.get("should_checkpoint", False)
  81. use_cuda = torch.cuda.is_available()
  82. device = torch.device("cuda" if use_cuda else "cpu")
  83. train_loader, test_loader = get_data_loaders()
  84. model = ConvNet().to(device)
  85. optimizer = optim.SGD(
  86. model.parameters(), lr=config["lr"], momentum=config["momentum"]
  87. )
  88. while True:
  89. train_func(model, optimizer, train_loader, device)
  90. acc = test_func(model, test_loader, device)
  91. metrics = {"mean_accuracy": acc}
  92. # Report metrics (and possibly a checkpoint)
  93. if should_checkpoint:
  94. with tempfile.TemporaryDirectory() as tempdir:
  95. torch.save(model.state_dict(), os.path.join(tempdir, "model.pt"))
  96. tune.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
  97. else:
  98. tune.report(metrics)
  99. if __name__ == "__main__":
  100. parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
  101. parser.add_argument(
  102. "--cuda", action="store_true", default=False, help="Enables GPU training"
  103. )
  104. parser.add_argument(
  105. "--smoke-test", action="store_true", help="Finish quickly for testing"
  106. )
  107. args, _ = parser.parse_known_args()
  108. ray.init(num_cpus=2 if args.smoke_test else None)
  109. # for early stopping
  110. sched = AsyncHyperBandScheduler()
  111. resources_per_trial = {"cpu": 2, "gpu": int(args.cuda)} # set this for GPUs
  112. tuner = tune.Tuner(
  113. tune.with_resources(train_mnist, resources=resources_per_trial),
  114. tune_config=tune.TuneConfig(
  115. metric="mean_accuracy",
  116. mode="max",
  117. scheduler=sched,
  118. num_samples=1 if args.smoke_test else 50,
  119. ),
  120. run_config=tune.RunConfig(
  121. name="exp",
  122. stop={
  123. "mean_accuracy": 0.98,
  124. "training_iteration": 5 if args.smoke_test else 100,
  125. },
  126. ),
  127. param_space={
  128. "lr": tune.loguniform(1e-4, 1e-2),
  129. "momentum": tune.uniform(0.1, 0.9),
  130. },
  131. )
  132. results = tuner.fit()
  133. print("Best config is:", results.get_best_result().config)
  134. assert not results.errors