| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- # Original Code here:
- # https://github.com/pytorch/examples/blob/master/mnist/main.py
- import argparse
- import os
- import tempfile
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- from filelock import FileLock
- from torchvision import datasets, transforms
- import ray
- from ray import tune
- from ray.tune import Checkpoint
- from ray.tune.schedulers import AsyncHyperBandScheduler
# Change these values if you want the training to run quicker or slower.
# Both are sample-count thresholds: train_func / test_func stop iterating
# once ``batch_idx * len(batch)`` exceeds the corresponding value.
EPOCH_SIZE = 512  # threshold used by train_func for one training pass
TEST_SIZE = 256  # threshold used by test_func for one evaluation pass
class ConvNet(nn.Module):
    """Minimal CNN classifier: one 3x3 convolution followed by one linear layer.

    The linear layer expects 192 features per sample, i.e. 3 channels x 8 x 8,
    which is what a single-channel 28x28 input yields after the convolution
    and the 3x3 max-pool.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        """Return per-class log-probabilities with shape ``(batch, 10)``.

        Args:
            x: Single-channel image tensor, batched ``(N, 1, H, W)`` or
                unbatched ``(1, H, W)``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce the
                192 features per sample that the linear layer expects.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        # Flatten each sample on its own (channels, height, width -> features).
        # A blanket ``view(-1, 192)`` would silently spill the extra features
        # of a wrongly sized input into the batch dimension; this way the
        # mismatch surfaces as a shape error in the linear layer instead.
        x = x.flatten(start_dim=-3)
        if x.dim() == 1:
            # Unbatched input: restore the leading batch axis so the output
            # is still two-dimensional.
            x = x.unsqueeze(0)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)
def train_func(model, optimizer, train_loader, device=None):
    """Run one truncated training pass over ``train_loader``.

    Args:
        model: Module to optimise; it is switched to training mode.
        optimizer: Optimizer bound to ``model``'s parameters.
        train_loader: Iterable yielding ``(data, target)`` batches.
        device: Device the batches are moved to; CPU when falsy.

    The pass ends as soon as ``batch_idx * len(batch)`` exceeds
    ``EPOCH_SIZE``, or when the loader is exhausted.
    """
    if not device:
        device = torch.device("cpu")
    model.train()

    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        if step * len(inputs) > EPOCH_SIZE:
            # Sample budget for this pass is used up.
            break
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        F.nll_loss(model(inputs), labels).backward()
        optimizer.step()
def test_func(model, data_loader, device=None):
    """Return the classification accuracy of ``model`` on ``data_loader``.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        data_loader: Iterable yielding ``(data, target)`` batches.
        device: Device the batches are moved to; CPU when falsy.

    Returns:
        Fraction of correctly classified samples among the batches visited.
        Iteration stops once ``batch_idx * len(batch)`` exceeds ``TEST_SIZE``.
    """
    if not device:
        device = torch.device("cpu")
    model.eval()

    num_correct, num_seen = 0, 0
    with torch.no_grad():
        for step, (inputs, labels) in enumerate(data_loader):
            if step * len(inputs) > TEST_SIZE:
                break
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Predicted class = index of the largest score along dim 1.
            predictions = torch.max(model(inputs), dim=1).indices
            num_seen += labels.size(0)
            num_correct += int((predictions == labels).sum())
    return num_correct / num_seen
def get_data_loaders(batch_size=64):
    """Build shuffled MNIST train and test loaders.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    def _make_loader(dataset):
        # Both loaders share the same batching and shuffling settings.
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )

    # Several workers may try to download the dataset at the same time;
    # the file lock serialises them so they do not overwrite each other.
    with FileLock(os.path.expanduser("~/data.lock")):
        train_set = datasets.MNIST(
            "~/data", train=True, download=True, transform=mnist_transforms
        )
        test_set = datasets.MNIST(
            "~/data", train=False, download=True, transform=mnist_transforms
        )
        train_loader = _make_loader(train_set)
        test_loader = _make_loader(test_set)
    return train_loader, test_loader
def train_mnist(config):
    """Tune trainable: train ``ConvNet`` on MNIST and report accuracy forever.

    Args:
        config: Mapping with ``"lr"`` and ``"momentum"`` for SGD, and an
            optional ``"should_checkpoint"`` flag (default ``False``).

    The loop never returns on its own; each iteration reports
    ``mean_accuracy`` through ``tune.report``.
    """
    should_checkpoint = config.get("should_checkpoint", False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = get_data_loaders()

    model = ConvNet().to(device)
    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"]
    )

    while True:
        train_func(model, optimizer, train_loader, device)
        metrics = {"mean_accuracy": test_func(model, test_loader, device)}

        if not should_checkpoint:
            tune.report(metrics)
            continue

        # Report from inside the context manager: the temporary directory
        # backing the checkpoint is removed when the block exits.
        with tempfile.TemporaryDirectory() as tempdir:
            weights_path = os.path.join(tempdir, "model.pt")
            torch.save(model.state_dict(), weights_path)
            tune.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
if __name__ == "__main__":
    # Command-line options; unknown arguments are tolerated and ignored.
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--cuda", action="store_true", default=False, help="Enables GPU training"
    )
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    args, _ = parser.parse_known_args()
    smoke_test = args.smoke_test

    ray.init(num_cpus=2 if smoke_test else None)

    # Scheduler used for early stopping of trials.
    sched = AsyncHyperBandScheduler()

    # Each trial gets two CPUs, plus one GPU when --cuda is passed.
    resources_per_trial = {"cpu": 2, "gpu": int(args.cuda)}
    trainable = tune.with_resources(train_mnist, resources=resources_per_trial)

    tune_config = tune.TuneConfig(
        metric="mean_accuracy",
        mode="max",
        scheduler=sched,
        num_samples=1 if smoke_test else 50,
    )
    # A trial stops at 98% accuracy or when the iteration cap is reached.
    run_config = tune.RunConfig(
        name="exp",
        stop={
            "mean_accuracy": 0.98,
            "training_iteration": 5 if smoke_test else 100,
        },
    )
    search_space = {
        "lr": tune.loguniform(1e-4, 1e-2),
        "momentum": tune.uniform(0.1, 0.9),
    }

    tuner = tune.Tuner(
        trainable,
        tune_config=tune_config,
        run_config=run_config,
        param_space=search_space,
    )
    results = tuner.fit()

    print("Best config is:", results.get_best_result().config)
    assert not results.errors
|