| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- # ruff: noqa
- # fmt: off
- # __import_begin__
- import os
- import tempfile
- from typing import Dict
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- import torchvision
- import torchvision.transforms as transforms
- from filelock import FileLock
- from torch.utils.data import random_split
- import ray
- from ray import tune
- from ray.tune import Checkpoint
- from ray.tune.schedulers import ASHAScheduler
- # __import_end__
- # __load_data_begin__
# Fresh temporary directory (created at import time) that CIFAR-10 is downloaded into.
DATA_DIR = tempfile.mkdtemp()


def load_data(data_dir):
    """Return the CIFAR-10 ``(trainset, testset)`` pair rooted at *data_dir*.

    Both splits are downloaded if absent, converted to tensors and normalized
    per channel with fixed CIFAR-10 mean/std constants.
    """
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    def _split(train):
        return torchvision.datasets.CIFAR10(
            root=data_dir, train=train, download=True, transform=transform
        )

    # Several trials may call this concurrently; the inter-process file lock
    # keeps them from downloading/extracting into the same directory at once.
    lock_path = os.path.expanduser("~/.data.lock")
    with FileLock(lock_path):
        trainset = _split(True)
        testset = _split(False)
    return trainset, testset
- # __load_data_end__
def load_test_data():
    """Return small random ``(trainset, testset)`` stand-ins for CIFAR-10.

    Uses ``FakeData`` (128 train / 16 test images of shape 3x32x32, 10
    classes) so smoke tests need no external download.
    """
    def _fake(size):
        return torchvision.datasets.FakeData(
            size, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
        )

    return _fake(128), _fake(16)
- # __net_begin__
class Net(nn.Module):
    """Small LeNet-style CNN producing 10 class scores per 3-channel image.

    Args:
        l1: Width of the first fully connected hidden layer.
        l2: Width of the second fully connected hidden layer.
    """

    def __init__(self, l1=120, l2=84):
        super().__init__()
        # For a 32x32 input, conv(5) -> 28, pool -> 14, conv(5) -> 10,
        # pool -> 5, leaving 16 feature maps of 5x5 for the first linear layer.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        """Return unnormalized class scores (logits), one row of 10 per sample."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)
- # __net_end__
- # __train_begin__
def train_cifar(config):
    """Train one CIFAR-10 ``Net`` trial under Ray Tune, reporting once per epoch.

    Args:
        config: Trial hyperparameters. Reads ``"l1"`` and ``"l2"`` (hidden
            layer widths passed to ``Net``), ``"lr"`` (SGD learning rate),
            ``"batch_size"``, and ``"smoke_test"`` (use fake data and no
            loader worker processes). Optionally reads ``"num_epochs"``
            (default 10); the Tune scheduler may stop the trial earlier.

    Each epoch trains on a random 80% of the training split, evaluates on the
    remaining 20%, and calls ``tune.report`` with ``"loss"`` (mean validation
    loss per batch) and ``"accuracy"``, together with a checkpoint whose
    ``checkpoint.pt`` holds ``(model_state, optimizer_state)``. When
    ``tune.get_checkpoint()`` returns a checkpoint, training resumes from it.
    """
    model = Net(config["l1"], config["l2"])
    # ``net`` is what forward passes go through; ``model`` is always the bare
    # Net, so saved/loaded state_dict keys never carry DataParallel's
    # "module." prefix (which a bare Net would reject on load).
    net = model
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(model)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # Load existing checkpoint through the `get_checkpoint()` API.
    loaded_checkpoint = tune.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt"),
                map_location=device,
            )
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    if config["smoke_test"]:
        trainset, _ = load_test_data()
    else:
        trainset, _ = load_data(DATA_DIR)

    train_size = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [train_size, len(trainset) - train_size])

    loader_kwargs = dict(
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=0 if config["smoke_test"] else 8,
    )
    trainloader = torch.utils.data.DataLoader(train_subset, **loader_kwargs)
    valloader = torch.utils.data.DataLoader(val_subset, **loader_kwargs)

    num_epochs = int(config.get("num_epochs", 10))
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        net.train()
        running_loss = 0.0
        window_steps = 0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients, then forward + backward + optimize
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            window_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / window_steps))
                # Reset both accumulators so each print is the mean of its
                # own 2000-batch window.
                running_loss = 0.0
                window_steps = 0

        # Validation loss / accuracy on the held-out 20%.
        net.eval()
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_loss += criterion(outputs, labels).item()
                val_steps += 1

        # Save a checkpoint; it is registered with Ray Tune by `tune.report`
        # and may be handed back through `get_checkpoint()` later. A file
        # must live inside a directory to build a Checkpoint, and the report
        # happens while that temporary directory still exists.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save((model.state_dict(), optimizer.state_dict()), path)
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            tune.report(
                {"loss": (val_loss / val_steps), "accuracy": correct / total},
                checkpoint=checkpoint,
            )
    print("Finished Training")
- # __train_end__
- # __test_acc_begin__
def test_best_model(config: Dict, checkpoint: "Checkpoint", smoke_test=False):
    """Load the weights stored in *checkpoint* and print test-set accuracy.

    Args:
        config: The trial config; ``"l1"`` and ``"l2"`` size the ``Net``.
        checkpoint: Tune checkpoint whose directory contains
            ``checkpoint.pt`` holding ``(model_state, optimizer_state)``.
        smoke_test: Evaluate on the small fake test set instead of CIFAR-10.
    """
    best_trained_model = Net(config["l1"], config["l2"])
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    best_trained_model.to(device)

    with checkpoint.as_directory() as checkpoint_dir:
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt")
        # map_location lets tensors saved on a GPU load on this host's device.
        model_state, _optimizer_state = torch.load(
            checkpoint_path, map_location=device)

    # A model trained inside nn.DataParallel saves keys as "module.<name>";
    # strip that prefix so the bare Net accepts either form.
    prefix = "module."
    model_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in model_state.items()
    }
    best_trained_model.load_state_dict(model_state)
    best_trained_model.eval()

    if smoke_test:
        _, testset = load_test_data()
    else:
        _, testset = load_data(DATA_DIR)

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False,
        num_workers=0 if smoke_test else 2)

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = best_trained_model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print("Best trial test set accuracy: {}".format(correct / total))
- # __test_acc_end__
- # __main_begin__
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2, smoke_test=False):
    """Run an ASHA-scheduled Ray Tune search over ``Net`` hyperparameters.

    After tuning, prints the best trial's config and final validation
    metrics, then evaluates its checkpoint with ``test_best_model``.

    Args:
        num_samples: Number of hyperparameter configurations to sample.
        max_num_epochs: ``max_t`` passed to the ASHA scheduler.
        gpus_per_trial: GPUs reserved per trial (each trial also reserves
            2 CPUs).
        smoke_test: Forwarded to trials and to the final test to use fake data.
    """
    # Layer widths are powers of two from 2**2 to 2**8 (randint's upper
    # bound is exclusive).
    search_space = {
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
        "smoke_test": smoke_test,
    }
    asha = ASHAScheduler(max_t=max_num_epochs, grace_period=1, reduction_factor=2)

    trainable = tune.with_resources(
        tune.with_parameters(train_cifar),
        resources={"cpu": 2, "gpu": gpus_per_trial},
    )
    tune_config = tune.TuneConfig(
        metric="loss",
        mode="min",
        num_samples=num_samples,
        scheduler=asha,
    )
    tuner = tune.Tuner(trainable, tune_config=tune_config, param_space=search_space)

    best = tuner.fit().get_best_result("loss", "min")
    final_metrics = best.metrics
    print("Best trial config: {}".format(best.config))
    print("Best trial final validation loss: {}".format(final_metrics["loss"]))
    print("Best trial final validation accuracy: {}".format(final_metrics["accuracy"]))

    test_best_model(best.config, best.checkpoint, smoke_test=smoke_test)
- # __main_end__
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test",
        action="store_true",
        help="Finish quickly for testing",
    )
    cli.add_argument(
        "--ray-address",
        required=False,
        help="Address of Ray cluster for seamless distributed execution.",
    )
    # parse_known_args tolerates extra, unrecognized command-line flags.
    cli_args, _unknown = cli.parse_known_args()

    if not cli_args.smoke_test:
        ray.init(cli_args.ray_address)
        # Change gpus_per_trial to activate training on GPUs.
        main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)
    else:
        ray.init(num_cpus=2)
        main(num_samples=1, max_num_epochs=1, gpus_per_trial=0, smoke_test=True)
|