cifar10_pytorch.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # ruff: noqa
  2. # fmt: off
  3. # __import_begin__
  4. import os
  5. import tempfile
  6. from typing import Dict
  7. import numpy as np
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. import torch.optim as optim
  12. import torchvision
  13. import torchvision.transforms as transforms
  14. from filelock import FileLock
  15. from torch.utils.data import random_split
  16. import ray
  17. from ray import tune
  18. from ray.tune import Checkpoint
  19. from ray.tune.schedulers import ASHAScheduler
  20. # __import_end__
  21. # __load_data_begin__
# Download root handed to load_data() by train_cifar() and test_best_model().
# A fresh temporary directory is created each time this module is executed.
DATA_DIR = tempfile.mkdtemp()
  23. def load_data(data_dir):
  24. transform = transforms.Compose([
  25. transforms.ToTensor(),
  26. transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
  27. ])
  28. # We add FileLock here because multiple workers will want to
  29. # download data, and this may cause overwrites since
  30. # DataLoader is not threadsafe.
  31. with FileLock(os.path.expanduser("~/.data.lock")):
  32. trainset = torchvision.datasets.CIFAR10(
  33. root=data_dir, train=True, download=True, transform=transform)
  34. testset = torchvision.datasets.CIFAR10(
  35. root=data_dir, train=False, download=True, transform=transform)
  36. return trainset, testset
  37. # __load_data_end__
  38. def load_test_data():
  39. # Loads a fake dataset for testing so it doesn't rely on external download.
  40. trainset = torchvision.datasets.FakeData(
  41. 128, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
  42. )
  43. testset = torchvision.datasets.FakeData(
  44. 16, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
  45. )
  46. return trainset, testset
  47. # __net_begin__
  48. class Net(nn.Module):
  49. def __init__(self, l1=120, l2=84):
  50. super(Net, self).__init__()
  51. self.conv1 = nn.Conv2d(3, 6, 5)
  52. self.pool = nn.MaxPool2d(2, 2)
  53. self.conv2 = nn.Conv2d(6, 16, 5)
  54. self.fc1 = nn.Linear(16 * 5 * 5, l1)
  55. self.fc2 = nn.Linear(l1, l2)
  56. self.fc3 = nn.Linear(l2, 10)
  57. def forward(self, x):
  58. x = self.pool(F.relu(self.conv1(x)))
  59. x = self.pool(F.relu(self.conv2(x)))
  60. x = x.view(-1, 16 * 5 * 5)
  61. x = F.relu(self.fc1(x))
  62. x = F.relu(self.fc2(x))
  63. x = self.fc3(x)
  64. return x
  65. # __net_end__
  66. # __train_begin__
  67. def train_cifar(config):
  68. net = Net(config["l1"], config["l2"])
  69. device = "cpu"
  70. if torch.cuda.is_available():
  71. device = "cuda:0"
  72. if torch.cuda.device_count() > 1:
  73. net = nn.DataParallel(net)
  74. net.to(device)
  75. criterion = nn.CrossEntropyLoss()
  76. optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
  77. # Load existing checkpoint through `get_checkpoint()` API.
  78. if tune.get_checkpoint():
  79. loaded_checkpoint = tune.get_checkpoint()
  80. with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
  81. model_state, optimizer_state = torch.load(
  82. os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
  83. )
  84. net.load_state_dict(model_state)
  85. optimizer.load_state_dict(optimizer_state)
  86. if config["smoke_test"]:
  87. trainset, testset = load_test_data()
  88. else:
  89. trainset, testset = load_data(DATA_DIR)
  90. test_abs = int(len(trainset) * 0.8)
  91. train_subset, val_subset = random_split(
  92. trainset, [test_abs, len(trainset) - test_abs])
  93. trainloader = torch.utils.data.DataLoader(
  94. train_subset,
  95. batch_size=int(config["batch_size"]),
  96. shuffle=True,
  97. num_workers=0 if config["smoke_test"] else 8,
  98. )
  99. valloader = torch.utils.data.DataLoader(
  100. val_subset,
  101. batch_size=int(config["batch_size"]),
  102. shuffle=True,
  103. num_workers=0 if config["smoke_test"] else 8,
  104. )
  105. for epoch in range(10): # loop over the dataset multiple times
  106. running_loss = 0.0
  107. epoch_steps = 0
  108. for i, data in enumerate(trainloader):
  109. # get the inputs; data is a list of [inputs, labels]
  110. inputs, labels = data
  111. inputs, labels = inputs.to(device), labels.to(device)
  112. # zero the parameter gradients
  113. optimizer.zero_grad()
  114. # forward + backward + optimize
  115. outputs = net(inputs)
  116. loss = criterion(outputs, labels)
  117. loss.backward()
  118. optimizer.step()
  119. # print statistics
  120. running_loss += loss.item()
  121. epoch_steps += 1
  122. if i % 2000 == 1999: # print every 2000 mini-batches
  123. print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
  124. running_loss / epoch_steps))
  125. running_loss = 0.0
  126. # Validation loss
  127. val_loss = 0.0
  128. val_steps = 0
  129. total = 0
  130. correct = 0
  131. for i, data in enumerate(valloader, 0):
  132. with torch.no_grad():
  133. inputs, labels = data
  134. inputs, labels = inputs.to(device), labels.to(device)
  135. outputs = net(inputs)
  136. _, predicted = torch.max(outputs.data, 1)
  137. total += labels.size(0)
  138. correct += (predicted == labels).sum().item()
  139. loss = criterion(outputs, labels)
  140. val_loss += loss.cpu().numpy()
  141. val_steps += 1
  142. # Here we save a checkpoint. It is automatically registered with
  143. # Ray Tune and will potentially be accessed through in ``get_checkpoint()``
  144. # in future iterations.
  145. # Note to save a file like checkpoint, you still need to put it under a directory
  146. # to construct a checkpoint.
  147. with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
  148. path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
  149. torch.save(
  150. (net.state_dict(), optimizer.state_dict()), path
  151. )
  152. checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
  153. tune.report(
  154. {"loss": (val_loss / val_steps), "accuracy": correct / total},
  155. checkpoint=checkpoint,
  156. )
  157. print("Finished Training")
  158. # __train_end__
  159. # __test_acc_begin__
  160. def test_best_model(config: Dict, checkpoint: "Checkpoint", smoke_test=False):
  161. best_trained_model = Net(config["l1"], config["l2"])
  162. device = "cuda:0" if torch.cuda.is_available() else "cpu"
  163. best_trained_model.to(device)
  164. with checkpoint.as_directory() as checkpoint_dir:
  165. checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt")
  166. model_state, optimizer_state = torch.load(checkpoint_path)
  167. best_trained_model.load_state_dict(model_state)
  168. if smoke_test:
  169. _, testset = load_test_data()
  170. else:
  171. _, testset = load_data(DATA_DIR)
  172. testloader = torch.utils.data.DataLoader(
  173. testset, batch_size=4, shuffle=False, num_workers=2)
  174. correct = 0
  175. total = 0
  176. with torch.no_grad():
  177. for data in testloader:
  178. images, labels = data
  179. images, labels = images.to(device), labels.to(device)
  180. outputs = best_trained_model(images)
  181. _, predicted = torch.max(outputs.data, 1)
  182. total += labels.size(0)
  183. correct += (predicted == labels).sum().item()
  184. print("Best trial test set accuracy: {}".format(correct / total))
  185. # __test_acc_end__
  186. # __main_begin__
  187. def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2, smoke_test=False):
  188. config = {
  189. "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
  190. "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
  191. "lr": tune.loguniform(1e-4, 1e-1),
  192. "batch_size": tune.choice([2, 4, 8, 16]),
  193. "smoke_test": smoke_test,
  194. }
  195. scheduler = ASHAScheduler(
  196. max_t=max_num_epochs,
  197. grace_period=1,
  198. reduction_factor=2)
  199. tuner = tune.Tuner(
  200. tune.with_resources(
  201. tune.with_parameters(train_cifar),
  202. resources={"cpu": 2, "gpu": gpus_per_trial},
  203. ),
  204. tune_config=tune.TuneConfig(
  205. metric="loss",
  206. mode="min",
  207. num_samples=num_samples,
  208. scheduler=scheduler
  209. ),
  210. param_space=config,
  211. )
  212. results = tuner.fit()
  213. best_result = results.get_best_result("loss", "min")
  214. print("Best trial config: {}".format(best_result.config))
  215. print("Best trial final validation loss: {}".format(
  216. best_result.metrics["loss"]))
  217. print("Best trial final validation accuracy: {}".format(
  218. best_result.metrics["accuracy"]))
  219. test_best_model(best_result.config, best_result.checkpoint, smoke_test=smoke_test)
  220. # __main_end__
  221. if __name__ == "__main__":
  222. import argparse
  223. parser = argparse.ArgumentParser()
  224. parser.add_argument(
  225. "--smoke-test", action="store_true", help="Finish quickly for testing")
  226. parser.add_argument(
  227. "--ray-address",
  228. help="Address of Ray cluster for seamless distributed execution.",
  229. required=False)
  230. args, _ = parser.parse_known_args()
  231. if args.smoke_test:
  232. ray.init(num_cpus=2)
  233. main(num_samples=1, max_num_epochs=1, gpus_per_trial=0, smoke_test=True)
  234. else:
  235. ray.init(args.ray_address)
  236. # Change this to activate training on GPUs
  237. main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)