@ResidentMario
Created March 28, 2021 19:01
CIFAR10Model(
  (cnn_block_1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Dropout(p=0.25, inplace=False)
  )
  (cnn_block_2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Dropout(p=0.25, inplace=False)
  )
  (head): Sequential(
    (0): Linear(in_features=4096, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=512, out_features=10, bias=True)
  )
)
Files already downloaded and verified
Files already downloaded and verified
Finished epoch 1/10, batch 0. loss: 2.307.
Finished epoch 1/10, batch 200. loss: 1.721.
Finished epoch 1/10, batch 400. loss: 1.441.
Finished epoch 1/10, batch 600. loss: 1.716.
Finished epoch 1/10, batch 800. loss: 1.410.
Finished epoch 1/10, batch 1000. loss: 1.575.
Finished epoch 1/10, batch 1200. loss: 1.527.
Finished epoch 1/10, batch 1400. loss: 1.518.
Finished epoch 1. avg loss: 1.619261189561125; median loss: 1.5913920402526855
Finished epoch 2/10, batch 0. loss: 1.433.
Finished epoch 2/10, batch 200. loss: 1.302.
Finished epoch 2/10, batch 400. loss: 1.135.
Finished epoch 2/10, batch 600. loss: 1.482.
Finished epoch 2/10, batch 800. loss: 1.205.
Finished epoch 2/10, batch 1000. loss: 1.361.
Finished epoch 2/10, batch 1200. loss: 1.324.
Finished epoch 2/10, batch 1400. loss: 1.385.
Finished epoch 2. avg loss: 1.329154274094509; median loss: 1.3225899934768677
Finished epoch 3/10, batch 0. loss: 1.154.
Finished epoch 3/10, batch 200. loss: 1.167.
Finished epoch 3/10, batch 400. loss: 1.076.
Finished epoch 3/10, batch 600. loss: 1.430.
Finished epoch 3/10, batch 800. loss: 1.041.
Finished epoch 3/10, batch 1000. loss: 1.185.
Finished epoch 3/10, batch 1200. loss: 1.083.
Finished epoch 3/10, batch 1400. loss: 1.271.
Finished epoch 3. avg loss: 1.1742576085750827; median loss: 1.167173147201538
Finished epoch 4/10, batch 0. loss: 1.043.
Finished epoch 4/10, batch 200. loss: 1.046.
Finished epoch 4/10, batch 400. loss: 0.885.
Finished epoch 4/10, batch 600. loss: 1.199.
Finished epoch 4/10, batch 800. loss: 0.838.
Finished epoch 4/10, batch 1000. loss: 1.111.
Finished epoch 4/10, batch 1200. loss: 0.980.
Finished epoch 4/10, batch 1400. loss: 1.136.
Finished epoch 4. avg loss: 1.0620599480409005; median loss: 1.0530399084091187
Finished epoch 5/10, batch 0. loss: 0.962.
Finished epoch 5/10, batch 200. loss: 0.881.
Finished epoch 5/10, batch 400. loss: 0.799.
Finished epoch 5/10, batch 600. loss: 1.058.
Finished epoch 5/10, batch 800. loss: 0.829.
Finished epoch 5/10, batch 1000. loss: 1.017.
Finished epoch 5/10, batch 1200. loss: 0.806.
Finished epoch 5/10, batch 1400. loss: 1.007.
Finished epoch 5. avg loss: 0.9728337521363891; median loss: 0.9638816714286804
Finished epoch 6/10, batch 0. loss: 0.798.
Finished epoch 6/10, batch 200. loss: 0.949.
Finished epoch 6/10, batch 400. loss: 0.780.
Finished epoch 6/10, batch 600. loss: 0.988.
Finished epoch 6/10, batch 800. loss: 0.797.
Finished epoch 6/10, batch 1000. loss: 0.872.
Finished epoch 6/10, batch 1200. loss: 0.800.
Finished epoch 6/10, batch 1400. loss: 0.978.
Finished epoch 6. avg loss: 0.910540230405384; median loss: 0.9099277257919312
Finished epoch 7/10, batch 0. loss: 0.767.
Finished epoch 7/10, batch 200. loss: 0.765.
Finished epoch 7/10, batch 400. loss: 0.716.
Finished epoch 7/10, batch 600. loss: 0.951.
Finished epoch 7/10, batch 800. loss: 0.716.
Finished epoch 7/10, batch 1000. loss: 0.943.
Finished epoch 7/10, batch 1200. loss: 0.693.
Finished epoch 7/10, batch 1400. loss: 0.892.
Finished epoch 7. avg loss: 0.8551562123777617; median loss: 0.8441058397293091
Finished epoch 8/10, batch 0. loss: 0.653.
Finished epoch 8/10, batch 200. loss: 0.876.
Finished epoch 8/10, batch 400. loss: 0.705.
Finished epoch 8/10, batch 600. loss: 1.054.
Finished epoch 8/10, batch 800. loss: 0.562.
Finished epoch 8/10, batch 1000. loss: 0.766.
Finished epoch 8/10, batch 1200. loss: 0.675.
Finished epoch 8/10, batch 1400. loss: 0.923.
Finished epoch 8. avg loss: 0.8119357479968593; median loss: 0.8003608584403992
Finished epoch 9/10, batch 0. loss: 0.605.
Finished epoch 9/10, batch 200. loss: 0.813.
Finished epoch 9/10, batch 400. loss: 0.664.
Finished epoch 9/10, batch 600. loss: 0.917.
Finished epoch 9/10, batch 800. loss: 0.447.
Finished epoch 9/10, batch 1000. loss: 0.712.
Finished epoch 9/10, batch 1200. loss: 0.668.
Finished epoch 9/10, batch 1400. loss: 0.975.
Finished epoch 9. avg loss: 0.7708600752451293; median loss: 0.759121298789978
Finished epoch 10/10, batch 0. loss: 0.466.
Finished epoch 10/10, batch 200. loss: 0.718.
Finished epoch 10/10, batch 400. loss: 0.699.
Finished epoch 10/10, batch 600. loss: 0.953.
Finished epoch 10/10, batch 800. loss: 0.606.
Finished epoch 10/10, batch 1000. loss: 0.756.
Finished epoch 10/10, batch 1200. loss: 0.760.
Finished epoch 10/10, batch 1400. loss: 0.961.
Finished epoch 10. avg loss: 0.7295191306871096; median loss: 0.718676745891571
CPU times: user 3min 12s, sys: 744 ms, total: 3min 13s
Wall time: 3min 13s
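
The PyTorch script that produced the output above: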
import torchvision
from torch.utils.data import DataLoader
import torch
from torch import nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os
# Train-time pipeline: horizontal-flip augmentation plus the standard
# CIFAR10 per-channel mean/std normalization constants.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
# Test-time pipeline: normalization only, no augmentation.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = torchvision.datasets.CIFAR10("/mnt/cifar10/", train=True, transform=transform_train, download=True)
# Note: shuffle=True is the usual choice for training data; this run used shuffle=False.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
test_dataset = torchvision.datasets.CIFAR10("/mnt/cifar10/", train=False, transform=transform_test, download=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
class CIFAR10Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Block 1: 3 -> 32 -> 64 channels; MaxPool2d(2) halves 32x32 inputs to 16x16.
        self.cnn_block_1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25)
        )
        # Block 2: 64 -> 64 channels; a second MaxPool2d(2) halves 16x16 to 8x8.
        self.cnn_block_2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25)
        )
        # Flatten (N, 64, 8, 8) feature maps into (N, 4096) vectors for the head.
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.head = nn.Sequential(
            nn.Linear(64 * 8 * 8, 512),  # 64 channels * 8 * 8 spatial = 4096 features
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)  # one logit per CIFAR10 class
        )

    def forward(self, X):
        X = self.cnn_block_1(X)
        X = self.cnn_block_2(X)
        X = self.flatten(X)
        X = self.head(X)
        return X
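
# Sanity check (my addition, not in the original gist): two MaxPool2d(2) layers
# halve 32x32 CIFAR10 images twice (32 -> 16 -> 8), so the conv blocks emit
# (N, 64, 8, 8) and flattening yields 64 * 8 * 8 = 4096 features, matching the
# in_features=4096 shown in the printed model above.
_m = CIFAR10Model()
assert _m.cnn_block_2(_m.cnn_block_1(torch.randn(1, 3, 32, 32))).shape == (1, 64, 8, 8)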
clf = CIFAR10Model()
start_epoch = 1
clf.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(clf.parameters(), lr=0.0001, weight_decay=1e-6)
def train():
    clf.train()  # enable dropout during training
    NUM_EPOCHS = 10
    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        losses = []
        for i, (X_batch, y_cls) in enumerate(train_dataloader):
            optimizer.zero_grad()
            y = y_cls.cuda()
            X_batch = X_batch.cuda()

            # Forward pass, loss, backward pass, parameter update.
            y_pred = clf(X_batch)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()

            train_loss = loss.item()
            if i % 200 == 0:
                print(
                    f'Finished epoch {epoch}/{NUM_EPOCHS}, batch {i}. loss: {train_loss:.3f}.'
                )
            losses.append(train_loss)
        print(
            f'Finished epoch {epoch}. '
            f'avg loss: {np.mean(losses)}; median loss: {np.median(losses)}'
        )

train()
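
The script builds test_dataloader but never evaluates on it. A minimal evaluation sketch that would fit after train() (the evaluate name and the accuracy metric are my additions, not part of the original gist):

def evaluate():
    clf.eval()  # disable dropout for deterministic predictions
    n_correct, n_total = 0, 0
    with torch.no_grad():
        for X_batch, y_cls in test_dataloader:
            # Predicted class = index of the largest logit.
            y_pred = clf(X_batch.cuda()).argmax(dim=1).cpu()
            n_correct += (y_pred == y_cls).sum().item()
            n_total += y_cls.shape[0]
    print(f'Test accuracy: {n_correct / n_total:.3f}')

evaluate()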