| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- import math
- import os
- import pytorch_lightning as pl
- import torch
- from datasets import load_dataset
- from filelock import FileLock
- from torch.nn import functional as F
- from torch.utils.data import DataLoader
- from torchmetrics import Accuracy
- from torchvision import transforms
- from ray import tune
- from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
# Root directory used as the dataset cache; override via the PATH_DATASETS
# environment variable, otherwise the current directory is used.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving the Hugging Face ``ylecun/mnist`` dataset.

    Every image is converted to a tensor and normalized with mean ``0.1307``
    and std ``0.3081``. The ``train`` split is used for training and the
    ``test`` split for validation.

    Args:
        batch_size: Number of samples per batch for both dataloaders.
        data_dir: Cache directory passed to ``datasets.load_dataset``.
    """

    def __init__(self, batch_size: int, data_dir: str = PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )
        self.batch_size = batch_size
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download the dataset into the cache directory.

        A file lock in the user's home directory serializes concurrent
        downloads, e.g. from several Tune trials running on the same machine.
        """
        with FileLock(os.path.expanduser("~/.data.lock")):
            load_dataset("ylecun/mnist", cache_dir=self.data_dir)

    def setup(self, stage=None):
        """Load the cached dataset and pre-transform every sample.

        Both splits are materialized eagerly as lists of ``(image, label)``
        tuples, so the whole transformed dataset is held in memory.
        ``stage`` is accepted for the Lightning hook signature but ignored.
        """
        dataset = load_dataset("ylecun/mnist", cache_dir=self.data_dir)

        def transform_fn(sample):
            return (self.transform(sample["image"]), sample["label"])

        self.mnist_train = [transform_fn(sample) for sample in dataset["train"]]
        self.mnist_val = [transform_fn(sample) for sample in dataset["test"]]

    def train_dataloader(self):
        """Return the training loader, reshuffled every epoch."""
        # DataLoader defaults to shuffle=False, which would feed the model the
        # training samples in the same fixed order every epoch.
        return DataLoader(
            self.mnist_train, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Return the validation loader in a deterministic order."""
        # Order does not change validation metrics, so keep it reproducible.
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, shuffle=False
        )
class LightningMNISTClassifier(pl.LightningModule):
    """Three-layer MLP classifier for 28x28 single-channel MNIST images.

    Args:
        config: Hyperparameters; must contain ``"lr"``, ``"layer_1"``,
            ``"layer_2"`` and ``"batch_size"``.
        data_dir: Stored on the instance (defaults to the current working
            directory); not otherwise used by this class.
    """

    def __init__(self, config, data_dir=None):
        super().__init__()
        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        self.batch_size = config["batch_size"]

        # MNIST images are (1, 28, 28) = (channels, height, width); the first
        # layer consumes the flattened 28 * 28 pixels.
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
        self.layer_3 = torch.nn.Linear(layer_2, 10)
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)

    def forward(self, x):
        """Map a 4-D image batch to per-class log-probabilities (batch, 10)."""
        # Unpacking enforces a 4-D input; only the batch dimension is needed.
        batch_size, _channels, _height, _width = x.size()
        x = x.view(batch_size, -1)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        x = self.layer_3(x)
        return torch.log_softmax(x, dim=1)

    def configure_optimizers(self):
        """Use Adam with the learning rate taken from ``config["lr"]``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute NLL loss on one training batch and log loss/accuracy."""
        x, y = train_batch
        log_probs = self.forward(x)
        loss = F.nll_loss(log_probs, y)
        acc = self.accuracy(log_probs, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute validation loss/accuracy and log them as epoch metrics.

        The values are logged with ``on_epoch=True``, so Lightning averages
        them over the whole validation epoch (weighted by ``batch_size``) and
        exposes ``ptl/val_loss`` / ``ptl/val_accuracy`` once per epoch. This
        replaces the ``validation_epoch_end`` hook, which Lightning 2.0
        removed (defining it there raises ``NotImplementedError``).
        """
        x, y = val_batch
        log_probs = self.forward(x)
        loss = F.nll_loss(log_probs, y)
        acc = self.accuracy(log_probs, y)
        batch_size = x.size(0)
        self.log(
            "ptl/val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )
        self.log(
            "ptl/val_accuracy",
            acc,
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )
        return {"val_loss": loss, "val_accuracy": acc}
- def train_mnist_tune(config, num_epochs=10, num_gpus=0):
- data_dir = os.path.abspath("./data")
- model = LightningMNISTClassifier(config, data_dir)
- with FileLock(os.path.expanduser("~/.data.lock")):
- dm = MNISTDataModule(data_dir=data_dir, batch_size=config["batch_size"])
- metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
- trainer = pl.Trainer(
- max_epochs=num_epochs,
- # If fractional GPUs passed in, convert to int.
- gpus=math.ceil(num_gpus),
- enable_progress_bar=False,
- callbacks=[
- TuneReportCheckpointCallback(
- metrics, on="validation_end", save_checkpoints=False
- )
- ],
- )
- trainer.fit(model, dm)
def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    """Run a Ray Tune hyperparameter search over the MNIST classifier.

    Each trial runs ``train_mnist_tune`` with 1 CPU and ``gpus_per_trial``
    GPUs. The search minimizes the trial metric ``"loss"``, and the best
    configuration is printed.

    Args:
        num_samples: Number of hyperparameter configurations to try.
        num_epochs: Maximum training epochs per trial.
        gpus_per_trial: GPUs reserved for each trial (may be fractional).

    Returns:
        The ``ResultGrid`` returned by ``Tuner.fit()``, so callers can inspect
        the best result's config and metrics programmatically.
    """
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    trainable = tune.with_parameters(
        train_mnist_tune, num_epochs=num_epochs, num_gpus=gpus_per_trial
    )
    tuner = tune.Tuner(
        tune.with_resources(trainable, resources={"cpu": 1, "gpu": gpus_per_trial}),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            num_samples=num_samples,
        ),
        run_config=tune.RunConfig(
            name="tune_mnist",
        ),
        param_space=config,
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)
    return results
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    cli_args, _unknown = cli.parse_known_args()

    # Smoke test: a single one-epoch trial; otherwise the full 10x10 search.
    # Both modes run on CPU only.
    num_samples, num_epochs = (1, 1) if cli_args.smoke_test else (10, 10)
    tune_mnist(num_samples=num_samples, num_epochs=num_epochs, gpus_per_trial=0)
|