Getting high accuracy on CIFAR-10 is not straightforward. This self-contained script gets to 94% accuracy with a minimal setup.
import argparse
import os

import torch
import torch.nn.functional as F
from torchvision import models, datasets, transforms
from tqdm import tqdm
def get_CIFAR10(root="./"):
    """Build the CIFAR-10 train/test datasets (downloading if needed).

    Args:
        root: Directory under which ``data/CIFAR10`` is created/looked up.
            Joined with ``os.path.join`` so a missing trailing slash no
            longer produces a mangled path (the old ``root + "data/..."``
            concatenation did).

    Returns:
        Tuple of ``(input_size, num_classes, train_dataset, test_dataset)``.
    """
    input_size = 32
    num_classes = 10
    # Channel-wise mean/std of the CIFAR-10 training split; shared by both
    # pipelines instead of being spelled out twice.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )
    data_dir = os.path.join(root, "data", "CIFAR10")

    # Standard CIFAR-10 augmentation: pad-and-crop plus horizontal flip.
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    train_dataset = datasets.CIFAR10(
        data_dir, train=True, transform=train_transform, download=True
    )

    # No augmentation at test time — only tensor conversion + normalization.
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])
    test_dataset = datasets.CIFAR10(
        data_dir, train=False, transform=test_transform, download=True
    )
    return input_size, num_classes, train_dataset, test_dataset
class Model(torch.nn.Module):
    """ResNet-18 adapted for 32x32 CIFAR images, emitting log-probabilities.

    The stock ImageNet stem (7x7/stride-2 conv followed by max-pooling)
    would shrink a 32x32 input too aggressively, so it is swapped for a
    3x3/stride-1 conv and the pooling layer is replaced with a no-op.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=False, num_classes=10)
        backbone.conv1 = torch.nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        backbone.maxpool = torch.nn.Identity()
        self.resnet = backbone

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        logits = self.resnet(x)
        return F.log_softmax(logits, dim=1)
def train(model, train_loader, optimizer, epoch):
    """Run one optimization epoch over ``train_loader``.

    Args:
        model: Network mapping an image batch to per-class log-probabilities.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        epoch: Epoch number, used only for the progress printout.
    """
    model.train()
    # Move batches to wherever the model's parameters live instead of
    # hard-coding .cuda(), so the same code also runs on CPU-only machines.
    device = next(model.parameters()).device
    total_loss = []
    for data, target in tqdm(train_loader):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        prediction = model(data)
        # The model outputs log-probabilities, so NLL loss is the right pairing.
        loss = F.nll_loss(prediction, target)
        loss.backward()
        optimizer.step()
        total_loss.append(loss.item())

    avg_loss = sum(total_loss) / len(total_loss)
    print(f"Epoch: {epoch}:")
    print(f"Train Set: Average Loss: {avg_loss:.2f}")
| def test(model, test_loader): | |
| model.eval() | |
| loss = 0 | |
| correct = 0 | |
| for data, target in test_loader: | |
| with torch.no_grad(): | |
| data = data.cuda() | |
| target = target.cuda() | |
| prediction = model(data) | |
| loss += F.nll_loss(prediction, target, reduction="sum") | |
| prediction = prediction.max(1)[1] | |
| correct += prediction.eq(target.view_as(prediction)).sum().item() | |
| loss /= len(test_loader.dataset) | |
| percentage_correct = 100.0 * correct / len(test_loader.dataset) | |
| print( | |
| "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)".format( | |
| loss, correct, len(test_loader.dataset), percentage_correct | |
| ) | |
| ) | |
| return loss, percentage_correct | |
def main():
    """Train ResNet-18 on CIFAR-10 from the command line and save weights."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--epochs", type=int, default=50, help="number of epochs to train (default: 50)"
    )
    parser.add_argument(
        "--lr", type=float, default=0.05, help="learning rate (default: 0.05)"
    )
    parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)")
    args = parser.parse_args()
    print(args)

    # Seed PyTorch for reproducible weight init and augmentation.
    torch.manual_seed(args.seed)

    input_size, num_classes, train_dataset, test_dataset = get_CIFAR10()

    loader_kwargs = {"num_workers": 2, "pin_memory": True}
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=128, shuffle=True, **loader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=5000, shuffle=False, **loader_kwargs
    )

    model = Model().cuda()

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4
    )
    # Step the LR down by 10x at epochs 25 and 40.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[25, 40], gamma=0.1
    )

    for epoch in range(1, args.epochs + 1):
        train(model, train_loader, optimizer, epoch)
        test(model, test_loader)
        scheduler.step()

    torch.save(model.state_dict(), "cifar_model.pt")
# Allow importing this file as a module without kicking off training.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment