# ruff: noqa
# fmt: off

# __import_begin__
import os
import tempfile
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from filelock import FileLock
from torch.utils.data import random_split

import ray
from ray import tune
from ray.tune import Checkpoint
from ray.tune.schedulers import ASHAScheduler
# __import_end__


# __load_data_begin__
DATA_DIR = tempfile.mkdtemp()


def load_data(data_dir):
    """Download (if needed) and return the CIFAR-10 train and test sets.

    Both splits are converted to tensors and normalized per channel with
    fixed mean/std constants.

    Args:
        data_dir: Directory the datasets are stored in / downloaded to.

    Returns:
        A ``(trainset, testset)`` tuple of ``torchvision.datasets.CIFAR10``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Several trials may try to download the dataset at the same time; the
    # file lock serializes the downloads so they don't overwrite each
    # other's files.
    with FileLock(os.path.expanduser("~/.data.lock")):
        trainset = torchvision.datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform)

        testset = torchvision.datasets.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset
# __load_data_end__


def load_test_data():
    """Return a tiny fake ``(trainset, testset)`` pair for smoke tests.

    Uses ``torchvision.datasets.FakeData`` so nothing is downloaded:
    128 training and 16 test images of shape (3, 32, 32) over 10 classes.
    """
    trainset = torchvision.datasets.FakeData(
        128, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
    )
    testset = torchvision.datasets.FakeData(
        16, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
    )
    return trainset, testset


# __net_begin__
class Net(nn.Module):
    """Small LeNet-style CNN producing 10 class logits.

    ``l1`` and ``l2`` are the widths of the two hidden fully connected
    layers and are the tunable hyperparameters. The ``view`` in ``forward``
    requires 3-channel 32x32 inputs: two 5x5 convolutions, each followed by
    a 2x2 max-pool, reduce 32x32 to 5x5.
    """

    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# __net_end__


# __train_begin__
def train_cifar(config):
    """Ray Tune trainable: train ``Net`` on CIFAR-10 and report val metrics.

    Args:
        config: Trial config with keys ``"l1"``, ``"l2"`` (hidden widths),
            ``"lr"`` (SGD learning rate), ``"batch_size"`` and
            ``"smoke_test"`` (use fake data and in-process data loading).

    After every epoch a checkpoint is saved and ``{"loss", "accuracy"}``
    on the validation split is reported via ``tune.report``.
    """
    net = Net(config["l1"], config["l2"])
    # Keep a handle on the bare model. If ``net`` gets wrapped in
    # nn.DataParallel below, ``net.state_dict()`` would prefix every key with
    # "module.", and such a checkpoint could not be loaded into the plain
    # ``Net`` used by ``test_best_model``. Save/restore ``model`` instead.
    model = net

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # Load existing checkpoint through `get_checkpoint()` API.
    loaded_checkpoint = tune.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            # map_location: the checkpoint may have been written on a
            # different device (e.g. a GPU node) than this trial runs on.
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt"),
                map_location=device,
            )
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    if config["smoke_test"]:
        trainset, testset = load_test_data()
    else:
        trainset, testset = load_data(DATA_DIR)

    train_size = int(len(trainset) * 0.8)
    # A fixed-seed generator gives every trial, and a trial restored from a
    # checkpoint, the same train/val split. Otherwise a restored trial could
    # validate on samples it had already trained on, and val losses would
    # not be comparable across trials.
    train_subset, val_subset = random_split(
        trainset, [train_size, len(trainset) - train_size],
        generator=torch.Generator().manual_seed(0))

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=0 if config["smoke_test"] else 8,
    )
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=0 if config["smoke_test"] else 8,
    )

    for epoch in range(10):  # loop over the dataset multiple times
        # Loss sum and step count since the last progress print; both are
        # reset together so the printed value is a true average.
        running_loss = 0.0
        running_steps = 0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            running_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / running_steps))
                running_loss = 0.0
                running_steps = 0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                val_loss += criterion(outputs, labels).item()
                val_steps += 1

        # Here we save a checkpoint. Passing it to ``tune.report`` registers
        # it with Ray Tune, and it may later be returned by
        # ``get_checkpoint()`` if this trial is restored.
        # A checkpoint is built from a directory, so even a single file must
        # first be written into one.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save(
                (model.state_dict(), optimizer.state_dict()), path
            )
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            tune.report(
                {"loss": (val_loss / val_steps), "accuracy": correct / total},
                checkpoint=checkpoint,
            )
    print("Finished Training")
# __train_end__


# __test_acc_begin__
def test_best_model(config: Dict, checkpoint: "Checkpoint", smoke_test=False):
    """Evaluate a checkpointed ``Net`` on the test split and print accuracy.

    Args:
        config: Trial config providing the ``"l1"`` / ``"l2"`` widths the
            checkpointed model was built with.
        checkpoint: Checkpoint containing ``checkpoint.pt`` as written by
            ``train_cifar`` (a ``(model_state, optimizer_state)`` tuple).
        smoke_test: Evaluate on the small fake dataset instead of CIFAR-10.
    """
    best_trained_model = Net(config["l1"], config["l2"])
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    best_trained_model.to(device)

    with checkpoint.as_directory() as checkpoint_dir:
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt")
        # map_location: the trial may have trained on a GPU while this
        # driver process has none.
        model_state, _ = torch.load(checkpoint_path, map_location=device)
        best_trained_model.load_state_dict(model_state)

    if smoke_test:
        _, testset = load_test_data()
    else:
        _, testset = load_data(DATA_DIR)

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = best_trained_model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print("Best trial test set accuracy: {}".format(correct / total))
# __test_acc_end__


# __main_begin__
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2, smoke_test=False):
    """Run an ASHA-scheduled Tune search and evaluate the best trial.

    Args:
        num_samples: Number of hyperparameter configurations to sample.
        max_num_epochs: ``max_t`` for the ASHA scheduler.
        gpus_per_trial: GPUs reserved for each trial (may be fractional or 0).
        smoke_test: Use fake data so the run finishes quickly.
    """
    config = {
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
        "smoke_test": smoke_test,
    }
    scheduler = ASHAScheduler(
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_cifar),
            resources={"cpu": 2, "gpu": gpus_per_trial},
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            num_samples=num_samples,
            scheduler=scheduler
        ),
        param_space=config,
    )
    results = tuner.fit()

    best_result = results.get_best_result("loss", "min")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(
        best_result.metrics["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_result.metrics["accuracy"]))

    test_best_model(best_result.config, best_result.checkpoint, smoke_test=smoke_test)
# __main_end__


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    parser.add_argument(
        "--ray-address",
        help="Address of Ray cluster for seamless distributed execution.",
        required=False)
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        ray.init(num_cpus=2)
        main(num_samples=1, max_num_epochs=1, gpus_per_trial=0, smoke_test=True)
    else:
        ray.init(args.ray_address)
        # Change this to activate training on GPUs
        main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)