Created February 1, 2022 04:24
Save Mahedi-61/0f3621f4241ebe9a692ee792e1dd36dc to your computer and use it in GitHub Desktop.
SqueezeNet on CIFAR-10 Dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.optim as optim | |
import numpy as np | |
from matplotlib import pyplot as plt | |
from tabnanny import check | |
from sklearn.metrics import confusion_matrix | |
import seaborn as sns | |
from torch import nn | |
import torch.nn.init as init | |
from torchvision import transforms, datasets | |
import os | |
# ---------------------------------------------------------------------------
# Run configuration: module-level globals read by the ``Train`` class.
# ---------------------------------------------------------------------------
lr = 1e-3            # initial learning rate handed to the Adam optimizer
is_load = True       # after loading ImageNet weights, resume from ./checkpoint/ckpt.pth
is_save = False      # save ./checkpoint/ckpt.pth whenever test accuracy beats best_acc
is_draw_conf = True  # make Train.test collect (true, predicted) labels and return them
num_epochs = 200     # number of epochs Train.train runs
device = "cuda" if torch.cuda.is_available() else "cpu"
best_acc = 0         # best test accuracy seen so far; updated by Train.test
start_epoch = 0      # start from epoch 0 or last checkpoint epoch (not read elsewhere in this file)
class Fire(nn.Module):
    """SqueezeNet "Fire" module.

    A 1x1 "squeeze" convolution (followed by ReLU) feeds two parallel
    "expand" convolutions, one 1x1 and one 3x3 (padding=1, so the spatial
    size is preserved). Each expand branch is passed through ReLU and the two
    results are concatenated along the channel dimension.

    Args:
        inplanes: number of channels of the incoming feature map.
        squeeze_planes: output channels of the 1x1 squeeze convolution.
        expand1x1_planes: output channels of the 1x1 expand branch.
        expand3x3_planes: output channels of the 3x3 expand branch.

    The output therefore has ``expand1x1_planes + expand3x3_planes`` channels.
    """

    def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
        super().__init__()
        self.inplanes = inplanes
        # The attribute names below become state_dict keys and must match the
        # pretrained checkpoint loaded into SqueezeNet; the construction order
        # is also kept so parameter order and RNG consumption are unchanged.
        self.squeeze = nn.Conv2d(in_channels=inplanes, out_channels=squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(in_channels=squeeze_planes, out_channels=expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(
            in_channels=squeeze_planes, out_channels=expand3x3_planes, kernel_size=3, padding=1
        )
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Squeeze ``x`` then return the channel-wise concatenation of both expand branches."""
        squeezed = self.squeeze_activation(self.squeeze(x))
        branch_1x1 = self.expand1x1_activation(self.expand1x1(squeezed))
        branch_3x3 = self.expand3x3_activation(self.expand3x3(squeezed))
        return torch.cat((branch_1x1, branch_3x3), dim=1)
class SqueezeNet(nn.Module):
    """SqueezeNet v1.0 / v1.1 image classifier.

    Submodule names and layer order are kept compatible with the pretrained
    SqueezeNet checkpoints (e.g. ``squeezenet1_1-b8a52dc0.pth``) so their
    state_dicts load with ``load_state_dict``.

    Args:
        version: ``"1_0"`` or ``"1_1"``; selects the ``features`` stack.
        num_classes: output channels of the final 1x1 convolution, i.e. the
            number of class scores produced per image.
        dropout: dropout probability applied just before the final conv.

    Raises:
        ValueError: if ``version`` is neither ``"1_0"`` nor ``"1_1"``.
    """

    def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        self.num_classes = num_classes
        if version == "1_0":
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == "1_1":
            # v1.1 uses a 3x3/64-channel stem instead of v1.0's 7x7/96.
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            # Validated here because this script constructs SqueezeNet
            # directly rather than through factory helpers.
            raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")

        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1))
        )

        # The final conv gets a small normal init (std=0.01); every other conv
        # uses Kaiming-uniform. All conv biases start at zero.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-image class scores of shape ``(N, C)``.

        ``C`` is the output-channel count of ``classifier[1]``. The classifier
        ends with ReLU + global average pooling, so the scores are non-negative.
        Assumes ``x`` is a batch of 3-channel images (N, 3, H, W) — TODO confirm
        for other callers.
        """
        x = self.features(x)
        x = self.classifier(x)
        return torch.flatten(x, 1)
class Train(nn.Module):
    """Fine-tune an ImageNet-pretrained SqueezeNet 1.1 on CIFAR-10.

    Construction does the following:
    - builds the CIFAR-10 train/test loaders, downloading the data to
      ``./data`` if needed;
    - loads the ImageNet weights from ``squeezenet1_1-b8a52dc0.pth`` in the
      working directory;
    - replaces the final conv with a 10-class head;
    - when the module-level ``is_load`` flag is set, resumes from
      ``./checkpoint/ckpt.pth``.

    Behaviour is driven by the module-level globals ``lr``, ``num_epochs``,
    ``device``, ``is_load``, ``is_save``, ``is_draw_conf`` and ``best_acc``.
    """

    # Per-channel normalisation statistics handed to transforms.Normalize.
    _CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
    _CIFAR_STD = (0.2023, 0.1994, 0.2010)

    def __init__(self):
        super().__init__()
        self._build_data()
        self.ls_train_acc = []  # training accuracy, appended every 5th epoch
        self.ls_test_acc = []   # test accuracy, appended every 5th epoch
        self._build_model()
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=lr,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=5e-4)
        # Multiplies the LR by 0.9 each time it is stepped (every 20 epochs in ``train``).
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.90)

    def _build_data(self):
        """Create the CIFAR-10 transforms, the train/test DataLoaders and the class names."""
        print('==> Preparing data..')
        # Training-only augmentation: random 32x32 crop of a 4-px padded image, plus flips.
        self.transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self._CIFAR_MEAN, self._CIFAR_STD),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self._CIFAR_MEAN, self._CIFAR_STD),
        ])
        self.trainloader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./data', train=True, download=True, transform=self.transform_train),
            batch_size=256,
            shuffle=True,
            num_workers=6)
        self.testloader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./data', train=False, download=True, transform=self.transform_test),
            batch_size=256,
            shuffle=False,
            num_workers=6)
        self.classes = ('plane', 'car', 'bird', 'cat', 'deer',
                        'dog', 'frog', 'horse', 'ship', 'truck')

    def _build_model(self):
        """Load ImageNet SqueezeNet 1.1 weights, swap in a CIFAR head, optionally resume."""
        # weights download from https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth
        print('==> Building model..')
        self.model = SqueezeNet(version="1_1")
        print("loading checkpoint from ImageNet")
        # The model is still on the CPU here; map_location keeps loading
        # independent of the device the checkpoint was saved from.
        checkpoint = torch.load("squeezenet1_1-b8a52dc0.pth", map_location="cpu")
        self.model.load_state_dict(checkpoint)
        # finetuning for CIFAR: replace the 1000-way final conv with a
        # len(self.classes)-way one and keep num_classes consistent with it.
        num_classes = len(self.classes)
        self.model.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        self.model.num_classes = num_classes
        self.model = self.model.to(device)
        if is_load:
            print('==> Resuming from checkpoint..')
            ckpt_path = './checkpoint/ckpt.pth'
            # Explicit check instead of ``assert`` (which ``python -O`` strips).
            if not os.path.isfile(ckpt_path):
                raise FileNotFoundError('Error: no checkpoint found at %s' % ckpt_path)
            # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
            self.model.load_state_dict(torch.load(ckpt_path, map_location=device))

    def train(self):
        """Train for ``num_epochs`` epochs.

        Every 5th epoch the epoch's training accuracy is appended to
        ``ls_train_acc`` and ``test`` is run; every 20th epoch the LR
        scheduler is stepped.
        """
        self.model.train()
        for epoch in range(1, num_epochs + 1):
            print('\nEpoch | %d' % epoch)
            train_loss = 0  # summed over mini-batches, not averaged
            correct = 0
            total = 0
            for inputs, targets in self.trainloader:
                inputs, targets = inputs.to(device), targets.to(device)
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

            train_acc = 100.0 * (correct / total)
            print('Train Loss: %.3f | Acc: %.3f%%' % (train_loss, train_acc))

            if epoch % 5 == 0:
                self.ls_train_acc.append(train_acc)
                self.test(epoch)
            if epoch % 20 == 0:
                self.scheduler.step()

    def test(self, epoch):
        """Evaluate the model on the CIFAR-10 test set.

        Args:
            epoch: current epoch; the accuracy is appended to ``ls_test_acc``
                only when it is a non-zero multiple of 5.

        Returns:
            ``(true_label, pred_label)`` lists when ``is_draw_conf`` is set,
            otherwise ``None``.

        Side effects: when ``is_save`` is set and the accuracy beats the global
        ``best_acc``, the state_dict is saved to ``./checkpoint/ckpt.pth`` and
        ``best_acc`` is updated. The model is left in train mode on return.
        """
        global best_acc
        self.model.eval()
        test_loss = 0
        correct = 0
        total = 0
        true_label = []
        pred_label = []
        with torch.no_grad():
            for inputs, targets in self.testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                if is_draw_conf:
                    true_label += targets.cpu().tolist()
                    pred_label += predicted.cpu().tolist()

        test_acc = 100.0 * (correct / total)
        print('Test Loss: %.3f | Acc: %.3f%%' % (test_loss, test_acc))
        if epoch % 5 == 0 and epoch != 0:
            self.ls_test_acc.append(test_acc)

        # Save checkpoint.
        if is_save and test_acc > best_acc:
            print('Saving..')
            os.makedirs('checkpoint', exist_ok=True)
            torch.save(self.model.state_dict(), './checkpoint/ckpt.pth')
            best_acc = test_acc

        self.model.train()
        if is_draw_conf:
            return true_label, pred_label

    @staticmethod
    def _row_percentages(c_matrix):
        """Return ``c_matrix`` as floats with each row scaled to percent of its row total.

        Rows that sum to zero (a class with no samples) stay all zeros.
        """
        counts = np.asarray(c_matrix, dtype=float)
        row_totals = counts.sum(axis=1, keepdims=True)
        return np.divide(counts * 100.0, row_totals,
                         out=np.zeros_like(counts), where=row_totals != 0)

    def draw_confusion_matrix(self, true_labels, pred_labels):
        """Show a heatmap of the confusion matrix.

        Each cell is annotated with the percentage of the row's true class
        that was predicted as the column's class.
        """
        # Fix the label set so the matrix is always len(classes) x len(classes),
        # matching the tick labels even if some class is absent from the data.
        c_matrix = confusion_matrix(y_true=true_labels, y_pred=pred_labels,
                                    labels=list(range(len(self.classes))))
        # Normalise per true class so annotations are percentages regardless
        # of how many samples each class has.
        percentages = self._row_percentages(c_matrix)
        label = np.asarray([["%.1f%%" % value for value in row] for row in percentages])
        sns.set(font_scale=1.4)
        sns.heatmap(percentages, annot=label, fmt='', linewidths=.5,
                    xticklabels=self.classes, yticklabels=self.classes)
        # labels, title
        plt.xlabel('Predicted Label', fontsize=10, labelpad=11)
        plt.ylabel('True Label', fontsize=10)
        plt.show()

    def plot_learning_curve(self, acc, label='Test accuracy', title='Test accuracy on CIFAR dataset'):
        """Plot an accuracy history recorded once every 5 epochs.

        ``ls_train_acc`` and ``ls_test_acc`` are recorded at that cadence. The
        x-axis is derived from ``len(acc)``, so runs of any length plot
        correctly. ``label`` and ``title`` default to the original test-accuracy
        wording.
        """
        epochs = list(range(5, 5 * len(acc) + 1, 5))
        plt.plot(epochs, acc, 'g', label=label)
        plt.title(title)
        plt.xlabel('Number of Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()
if __name__ == "__main__":
    trainer = Train()

    # Plot the training-accuracy history saved by an earlier run
    # (Train.train records one value every 5 epochs).
    # NOTE(review): plot_learning_curve labels its curve "Test accuracy", but
    # this loads the *train* history — confirm which file was intended.
    train_history = np.load("train_acc.npy").tolist()
    trainer.plot_learning_curve(train_history)

    # Full workflow, currently disabled:
    # trainer.train()
    # np.save("train_acc.npy", trainer.ls_train_acc)
    # np.save("test_acc.npy", trainer.ls_test_acc)
    # true_label, pred_label = trainer.test(epoch=0)
    # trainer.draw_confusion_matrix(true_label, pred_label)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.