Skip to content

Instantly share code, notes, and snippets.

@Mahedi-61
Created February 1, 2022 04:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Mahedi-61/0f3621f4241ebe9a692ee792e1dd36dc to your computer and use it in GitHub Desktop.
Save Mahedi-61/0f3621f4241ebe9a692ee792e1dd36dc to your computer and use it in GitHub Desktop.
SqueezeNet on CIFAR-10 Dataset
import torch
import torch.optim as optim
import numpy as np
from matplotlib import pyplot as plt
from tabnanny import check
from sklearn.metrics import confusion_matrix
import seaborn as sns
from torch import nn
import torch.nn.init as init
from torchvision import transforms, datasets
import os
# ---------------------------------------------------------------------------
# Run configuration (module-level globals read by the Train class below).
# ---------------------------------------------------------------------------

# Optimisation hyper-parameters.
num_epochs = 200
lr = 0.001

# Workflow switches.
is_load = True       # resume weights from ./checkpoint/ckpt.pth when building Train
is_save = False      # write ./checkpoint/ckpt.pth whenever test accuracy improves
is_draw_conf = True  # make Train.test collect and return true/predicted labels

# Bookkeeping shared across evaluation calls.
best_acc = 0         # best test accuracy (percent) seen so far
start_epoch = 0      # start from epoch 0 or last checkpoint epoch

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
class Fire(nn.Module):
    """SqueezeNet "Fire" block.

    A 1x1 "squeeze" convolution reduces the channel count, then two parallel
    "expand" convolutions (1x1 and 3x3 with padding 1) are applied to the
    squeezed tensor and their ReLU outputs are concatenated along the channel
    dimension.

    Args:
        inplanes: Number of input channels.
        squeeze_planes: Output channels of the squeeze convolution.
        expand1x1_planes: Output channels of the 1x1 expand branch.
        expand3x3_planes: Output channels of the 3x3 expand branch.
    """

    def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
        super().__init__()
        self.inplanes = inplanes

        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the order in which weights are initialised.
        self.squeeze = nn.Conv2d(in_channels=inplanes, out_channels=squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)

        self.expand1x1 = nn.Conv2d(in_channels=squeeze_planes, out_channels=expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)

        self.expand3x3 = nn.Conv2d(
            in_channels=squeeze_planes,
            out_channels=expand3x3_planes,
            kernel_size=3,
            padding=1,
        )
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the channel-wise concatenation of both expand branches."""
        squeezed = self.squeeze_activation(self.squeeze(x))
        branch_1x1 = self.expand1x1_activation(self.expand1x1(squeezed))
        branch_3x3 = self.expand3x3_activation(self.expand3x3(squeezed))
        return torch.cat([branch_1x1, branch_3x3], 1)
class SqueezeNet(nn.Module):
    """SqueezeNet image classifier built from ``Fire`` blocks.

    Two layer layouts are supported, selected by ``version``: ``"1_0"`` and
    ``"1_1"``. Both end with 512 feature channels, which the classifier maps
    to ``num_classes`` with a 1x1 convolution followed by global average
    pooling.

    Args:
        version: Layout to build, either ``"1_0"`` or ``"1_1"``.
        num_classes: Number of output classes (channels of the final conv).
        dropout: Dropout probability applied before the final convolution.

    Raises:
        ValueError: If ``version`` is neither ``"1_0"`` nor ``"1_1"``.
    """

    def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        self.num_classes = num_classes

        # NOTE: the position of every layer inside ``features`` is part of the
        # state_dict key (e.g. ``features.3.squeeze.weight``); do not reorder.
        if version == "1_0":
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == "1_1":
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            # The original message was "{%s}" % version, which printed stray
            # literal braces and broke on tuple arguments; use an f-string.
            raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")

        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1))
        )

        # Weight initialisation: small normal noise for the classifier conv,
        # Kaiming-uniform for every other conv, and zero for all conv biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class scores flattened to ``(batch, num_classes)``."""
        x = self.features(x)
        x = self.classifier(x)
        # The classifier ends in a 1x1 adaptive average pool, so flattening
        # from dim 1 leaves one score per class for each sample.
        return torch.flatten(x, 1)
class Train(nn.Module):
    """Fine-tune an ImageNet-pretrained SqueezeNet 1.1 on CIFAR-10.

    Construction prepares the CIFAR-10 data loaders, builds the model from the
    local ``squeezenet1_1-b8a52dc0.pth`` weights, swaps in a 10-class
    classifier convolution and (when the module-level ``is_load`` flag is set)
    restores ``./checkpoint/ckpt.pth``. The remaining methods run training,
    evaluation and plotting.

    Behaviour is driven by the module-level globals ``lr``, ``num_epochs``,
    ``device``, ``is_load``, ``is_save``, ``is_draw_conf`` and ``best_acc``.

    NOTE(review): this class subclasses ``nn.Module`` and redefines ``train``
    with a different signature than ``nn.Module.train(mode)``; calling
    ``.eval()`` on a ``Train`` instance itself would therefore fail. Only
    ``self.model.train()`` / ``self.model.eval()`` are used here.
    """

    # Evaluate on the test set (and record accuracies) every this many epochs.
    EVAL_INTERVAL = 5
    # Apply one exponential learning-rate decay step every this many epochs.
    LR_DECAY_INTERVAL = 20

    def __init__(self):
        super().__init__()

        # Data
        print('==> Preparing data..')
        self.transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        # No augmentation at test time, only tensor conversion + normalisation.
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        self.trainloader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./data', train=True, download=True, transform=self.transform_train),
            batch_size=256,
            shuffle=True,
            num_workers=6)
        self.testloader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./data', train=False, download=True, transform=self.transform_test),
            batch_size=256,
            shuffle=False,
            num_workers=6)

        self.classes = ('plane', 'car', 'bird', 'cat', 'deer',
                        'dog', 'frog', 'horse', 'ship', 'truck')

        # Accuracies (percent) recorded every EVAL_INTERVAL epochs.
        self.ls_train_acc = []
        self.ls_test_acc = []

        # weights download from https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth
        print('==> Building model..')
        self.model = SqueezeNet(version="1_1")
        print("loading checkpoint from ImageNet")
        # map_location makes the load independent of the device the file was
        # saved from (e.g. a CUDA checkpoint opened on a CPU-only machine).
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        checkpoint = torch.load("squeezenet1_1-b8a52dc0.pth", map_location=device)
        self.model.load_state_dict(checkpoint)

        # finetuning for CIFAR: replace the 1000-class conv with a 10-class one
        self.model.classifier[1] = torch.nn.Conv2d(512, 10, kernel_size=(1, 1), stride=(1, 1))
        self.model = self.model.to(device)

        if is_load:
            # Load checkpoint.
            print('==> Resuming from checkpoint..')
            assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
            checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=device)
            self.model.load_state_dict(checkpoint)

        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=lr,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=5e-4)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.90)

    @staticmethod
    def _row_percentages(c_matrix):
        """Return ``c_matrix`` with every row rescaled to percentages of its row total."""
        counts = np.asarray(c_matrix, dtype=float)
        totals = counts.sum(axis=1, keepdims=True)
        # A row with no samples stays all-zero instead of dividing by zero.
        safe_totals = np.where(totals == 0, 1.0, totals)
        return 100.0 * counts / safe_totals

    @staticmethod
    def _eval_epochs(count, interval=5):
        """Return the epoch numbers of ``count`` evaluations spaced ``interval`` epochs apart."""
        return list(range(interval, interval * count + 1, interval))

    # Training
    def train(self):
        """Run ``num_epochs`` epochs of training.

        Every ``EVAL_INTERVAL`` epochs the train accuracy is recorded in
        ``ls_train_acc`` and the test set is evaluated; every
        ``LR_DECAY_INTERVAL`` epochs the learning-rate scheduler is stepped.
        The printed loss is the sum of the per-batch losses of the epoch.
        """
        self.model.train()
        # Epochs are numbered from 1 so the modulo checks never fire at start.
        for epoch in range(1, num_epochs + 1):
            print('\nEpoch | %d' % epoch)
            train_loss = 0
            correct = 0
            total = 0

            for inputs, targets in self.trainloader:
                inputs, targets = inputs.to(device), targets.to(device)
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

            train_acc = 100.0 * (correct / total)
            print('Train Loss: %.3f | Acc: %.3f%%' % (train_loss, train_acc))

            if epoch % self.EVAL_INTERVAL == 0:
                self.ls_train_acc.append(train_acc)
                self.test(epoch)

            if epoch % self.LR_DECAY_INTERVAL == 0:
                self.scheduler.step()

    def test(self, epoch):
        """Evaluate the model on the test set.

        Args:
            epoch: Current epoch number. The accuracy is appended to
                ``ls_test_acc`` only when it is a non-zero multiple of
                ``EVAL_INTERVAL`` (pass 0 for a stand-alone evaluation).

        Returns:
            ``(true_labels, predicted_labels)`` as two lists when the global
            ``is_draw_conf`` is set, otherwise ``None``.

        Side effects:
            Prints loss/accuracy; when ``is_save`` is set and the accuracy
            beats the global ``best_acc``, saves the model state_dict to
            ``./checkpoint/ckpt.pth`` and updates ``best_acc``. The model is
            switched back to training mode before returning.
        """
        global best_acc
        self.model.eval()
        test_loss = 0
        correct = 0
        total = 0
        true_label = []
        pred_label = []

        with torch.no_grad():
            for inputs, targets in self.testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                if is_draw_conf:
                    true_label.extend(targets.cpu().detach().tolist())
                    pred_label.extend(predicted.cpu().detach().tolist())

        test_acc = 100.0 * (correct / total)
        print('Test Loss: %.3f | Acc: %.3f%%' % (test_loss, test_acc))

        if epoch % self.EVAL_INTERVAL == 0 and epoch != 0:
            self.ls_test_acc.append(test_acc)

        # Save checkpoint.
        if is_save:
            acc = 100.*correct/total
            if acc > best_acc:
                print('Saving..')
                if not os.path.isdir('checkpoint'):
                    os.mkdir('checkpoint')
                torch.save(self.model.state_dict(), './checkpoint/ckpt.pth')
                best_acc = acc

        # Restore training mode so a surrounding training loop can continue.
        self.model.train()
        if is_draw_conf:
            return true_label, pred_label
        return None

    def draw_confusion_matrix(self, true_labels, pred_labels):
        """Show a heatmap of the row-normalised confusion matrix.

        Each cell is the percentage of samples of the true class (row) that
        were predicted as the column class. Rows are normalised by their own
        totals, so the plot stays correct for any number of samples per class.

        Args:
            true_labels: Sequence of ground-truth class indices.
            pred_labels: Sequence of predicted class indices, same length.
        """
        # make confusion matrix
        c_matrix = confusion_matrix(y_true=true_labels, y_pred=pred_labels)
        percentages = self._row_percentages(c_matrix)
        label = np.asarray([["%.1f%%" % col for col in row] for row in percentages])

        sns.set(font_scale=1.4)
        sns.heatmap(percentages, annot=label, fmt='', linewidths=.5,
                    xticklabels=self.classes, yticklabels=self.classes)

        # labels, title
        plt.xlabel('Predicted Label', fontsize=10, labelpad=11)
        plt.ylabel('True Label', fontsize=10)
        plt.show()

    def plot_learning_curve(self, acc, label='Test accuracy'):
        """Plot accuracies that were recorded every ``EVAL_INTERVAL`` epochs.

        Args:
            acc: Sequence of accuracy values; the i-th value is plotted at
                epoch ``EVAL_INTERVAL * (i + 1)``.
            label: Legend label, also used as the prefix of the plot title.
        """
        # Derive the x axis from the data so curves of any length can be drawn.
        epochs = self._eval_epochs(len(acc), self.EVAL_INTERVAL)
        plt.plot(epochs, acc, 'g', label=label)
        plt.title('%s on CIFAR dataset' % label)
        plt.xlabel('Number of Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()
if __name__ == "__main__":
    # Building the trainer prepares the data loaders and loads the model.
    trainer = Train()

    # Current mode: plot the accuracy curve stored by an earlier training run.
    saved_train_acc = np.load("train_acc.npy")
    trainer.plot_learning_curve(saved_train_acc.tolist())

    # Alternative mode 1 - train and persist the recorded accuracy curves:
    #     trainer.train()
    #     np.save("train_acc.npy", trainer.ls_train_acc)
    #     np.save("test_acc.npy", trainer.ls_test_acc)
    #
    # Alternative mode 2 - evaluate once and draw the confusion matrix:
    #     true_label, pred_label = trainer.test(epoch=0)
    #     trainer.draw_confusion_matrix(true_label, pred_label)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment