-
-
Save ResidentMario/9c3a90504d1a027aab926fd65ae08139 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
CIFAR10Model( | |
(cnn_block_1): Sequential( | |
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) | |
(1): ReLU() | |
(2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) | |
(3): ReLU() | |
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) | |
) | |
(dropout_1): Dropout(p=0.25, inplace=False) | |
(cnn_block_2): Sequential( | |
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) | |
(1): ReLU() | |
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) | |
(3): ReLU() | |
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) | |
) | |
(dropout_2): Dropout(p=0.25, inplace=False) | |
(linearize): Sequential( | |
(0): Linear(in_features=4096, out_features=512, bias=True) | |
(1): ReLU() | |
) | |
(dropout_3): Dropout(p=0.5, inplace=False) | |
(out): Linear(in_features=512, out_features=10, bias=True) | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Files already downloaded and verified | |
Files already downloaded and verified | |
Finished epoch 1/10, batch 0. loss: 2.310. | |
Finished epoch 1/10, batch 200. loss: 2.168. | |
Finished epoch 1/10, batch 400. loss: 2.148. | |
Finished epoch 1/10, batch 600. loss: 2.215. | |
Finished epoch 1/10, batch 800. loss: 2.013. | |
Finished epoch 1/10, batch 1000. loss: 1.990. | |
Finished epoch 1/10, batch 1200. loss: 2.142. | |
Finished epoch 1/10, batch 1400. loss: 1.950. | |
Finished epoch 1. avg loss: 2.0989584789318836; median loss: 2.103419542312622 | |
Finished epoch 2/10, batch 0. loss: 1.955. | |
Finished epoch 2/10, batch 200. loss: 1.818. | |
Finished epoch 2/10, batch 400. loss: 1.939. | |
Finished epoch 2/10, batch 600. loss: 2.050. | |
Finished epoch 2/10, batch 800. loss: 1.904. | |
Finished epoch 2/10, batch 1000. loss: 1.781. | |
Finished epoch 2/10, batch 1200. loss: 2.044. | |
Finished epoch 2/10, batch 1400. loss: 1.856. | |
Finished epoch 2. avg loss: 1.9179350005001574; median loss: 1.9195481538772583 | |
Finished epoch 3/10, batch 0. loss: 1.829. | |
Finished epoch 3/10, batch 200. loss: 1.631. | |
Finished epoch 3/10, batch 400. loss: 1.774. | |
Finished epoch 3/10, batch 600. loss: 1.940. | |
Finished epoch 3/10, batch 800. loss: 1.789. | |
Finished epoch 3/10, batch 1000. loss: 1.765. | |
Finished epoch 3/10, batch 1200. loss: 2.025. | |
Finished epoch 3/10, batch 1400. loss: 1.819. | |
Finished epoch 3. avg loss: 1.8256724212540316; median loss: 1.8263983726501465 | |
Finished epoch 4/10, batch 0. loss: 1.787. | |
Finished epoch 4/10, batch 200. loss: 1.471. | |
Finished epoch 4/10, batch 400. loss: 1.513. | |
Finished epoch 4/10, batch 600. loss: 1.818. | |
Finished epoch 4/10, batch 800. loss: 1.772. | |
Finished epoch 4/10, batch 1000. loss: 1.769. | |
Finished epoch 4/10, batch 1200. loss: 1.889. | |
Finished epoch 4/10, batch 1400. loss: 1.760. | |
Finished epoch 4. avg loss: 1.7637811539192956; median loss: 1.7644797563552856 | |
Finished epoch 5/10, batch 0. loss: 1.759. | |
Finished epoch 5/10, batch 200. loss: 1.441. | |
Finished epoch 5/10, batch 400. loss: 1.579. | |
Finished epoch 5/10, batch 600. loss: 1.900. | |
Finished epoch 5/10, batch 800. loss: 1.574. | |
Finished epoch 5/10, batch 1000. loss: 1.718. | |
Finished epoch 5/10, batch 1200. loss: 1.986. | |
Finished epoch 5/10, batch 1400. loss: 1.735. | |
Finished epoch 5. avg loss: 1.7175132141461071; median loss: 1.7135170698165894 | |
Finished epoch 6/10, batch 0. loss: 1.683. | |
Finished epoch 6/10, batch 200. loss: 1.456. | |
Finished epoch 6/10, batch 400. loss: 1.518. | |
Finished epoch 6/10, batch 600. loss: 1.790. | |
Finished epoch 6/10, batch 800. loss: 1.610. | |
Finished epoch 6/10, batch 1000. loss: 1.584. | |
Finished epoch 6/10, batch 1200. loss: 2.033. | |
Finished epoch 6/10, batch 1400. loss: 1.560. | |
Finished epoch 6. avg loss: 1.6815758042814484; median loss: 1.6768428087234497 | |
Finished epoch 7/10, batch 0. loss: 1.638. | |
Finished epoch 7/10, batch 200. loss: 1.336. | |
Finished epoch 7/10, batch 400. loss: 1.347. | |
Finished epoch 7/10, batch 600. loss: 1.867. | |
Finished epoch 7/10, batch 800. loss: 1.584. | |
Finished epoch 7/10, batch 1000. loss: 1.650. | |
Finished epoch 7/10, batch 1200. loss: 1.761. | |
Finished epoch 7/10, batch 1400. loss: 1.666. | |
Finished epoch 7. avg loss: 1.6539903240789608; median loss: 1.6436301469802856 | |
Finished epoch 8/10, batch 0. loss: 1.646. | |
Finished epoch 8/10, batch 200. loss: 1.432. | |
Finished epoch 8/10, batch 400. loss: 1.353. | |
Finished epoch 8/10, batch 600. loss: 1.690. | |
Finished epoch 8/10, batch 800. loss: 1.673. | |
Finished epoch 8/10, batch 1000. loss: 1.671. | |
Finished epoch 8/10, batch 1200. loss: 1.860. | |
Finished epoch 8/10, batch 1400. loss: 1.578. | |
Finished epoch 8. avg loss: 1.6351439083377597; median loss: 1.6342540979385376 | |
Finished epoch 9/10, batch 0. loss: 1.734. | |
Finished epoch 9/10, batch 200. loss: 1.336. | |
Finished epoch 9/10, batch 400. loss: 1.311. | |
Finished epoch 9/10, batch 600. loss: 1.802. | |
Finished epoch 9/10, batch 800. loss: 1.446. | |
Finished epoch 9/10, batch 1000. loss: 1.589. | |
Finished epoch 9/10, batch 1200. loss: 1.731. | |
Finished epoch 9/10, batch 1400. loss: 1.603. | |
Finished epoch 9. avg loss: 1.6180183659435767; median loss: 1.6119329929351807 | |
Finished epoch 10/10, batch 0. loss: 1.711. | |
Finished epoch 10/10, batch 200. loss: 1.306. | |
Finished epoch 10/10, batch 400. loss: 1.369. | |
Finished epoch 10/10, batch 600. loss: 1.778. | |
Finished epoch 10/10, batch 800. loss: 1.599. | |
Finished epoch 10/10, batch 1000. loss: 1.636. | |
Finished epoch 10/10, batch 1200. loss: 1.846. | |
Finished epoch 10/10, batch 1400. loss: 1.538. | |
Finished epoch 10. avg loss: 1.5997641355809842; median loss: 1.5899784564971924 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torchvision | |
from torch.utils.data import DataLoader | |
import torch | |
from torch import nn | |
from torch import optim | |
from torch.utils.tensorboard import SummaryWriter | |
import numpy as np | |
import os | |
# Data pipeline: augmentation (random horizontal flip) + normalization for
# training; normalization only for evaluation. The mean/std tuples are the
# standard per-channel CIFAR-10 statistics.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = torchvision.datasets.CIFAR10("/mnt/cifar10/", train=True, transform=transform_train, download=True)
# shuffle=True: the original used shuffle=False, which feeds batches in the
# dataset's fixed (class-ordered) sequence every epoch and hurts SGD convergence.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataset = torchvision.datasets.CIFAR10("/mnt/cifar10/", train=False, transform=transform_test, download=True)
# Evaluation order does not matter; keep it deterministic. (Use the imported
# DataLoader name consistently instead of torch.utils.data.DataLoader.)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
class CIFAR10Model(nn.Module):
    """VGG-style CNN for CIFAR-10: (batch, 3, 32, 32) images -> (batch, 10) logits.

    Both convolutional blocks run under gradient checkpointing to trade
    recomputation for activation memory during training.
    """

    def __init__(self):
        super().__init__()
        # Block 1: 3 -> 32 -> 64 channels; MaxPool halves 32x32 -> 16x16.
        self.cnn_block_1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.dropout_1 = nn.Dropout(0.25)
        # Block 2: 64 -> 64 channels; MaxPool halves 16x16 -> 8x8.
        self.cnn_block_2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.dropout_2 = nn.Dropout(0.25)
        # nn.Flatten is a proper Module (picklable, shows up in the model repr),
        # unlike the bare lambda the original assigned here.
        self.flatten = nn.Flatten(1)
        self.linearize = nn.Sequential(
            nn.Linear(64 * 8 * 8, 512),  # 64 channels x 8 x 8 after two pools
            nn.ReLU(),
        )
        self.dropout_3 = nn.Dropout(0.5)
        self.out = nn.Linear(512, 10)  # raw class logits; pair with CrossEntropyLoss

    def forward(self, X):
        """Return raw class logits of shape (batch, 10)."""
        # use_reentrant=False: the reentrant (legacy default) checkpoint skips
        # the backward pass through a segment whose inputs do not require grad,
        # so cnn_block_1's weights would never receive gradients from the
        # (requires_grad=False) input images.
        X = torch.utils.checkpoint.checkpoint(self.cnn_block_1, X, use_reentrant=False)
        X = self.dropout_1(X)
        X = torch.utils.checkpoint.checkpoint(self.cnn_block_2, X, use_reentrant=False)
        X = self.dropout_2(X)
        X = self.flatten(X)
        X = self.linearize(X)
        X = self.dropout_3(X)
        X = self.out(X)
        return X
# Module-level training state, read directly by train() below.
clf = CIFAR10Model()
# First epoch number to run (1-based); left at 1 unless resuming a checkpoint.
start_epoch = 1
# Move all parameters to the GPU; requires a CUDA-capable device.
clf.cuda()
# CrossEntropyLoss expects raw logits from the model plus integer class labels.
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(clf.parameters(), lr=0.0001, weight_decay=1e-6)
def train(num_epochs=10):
    """Train `clf` on `train_dataloader` for epochs start_epoch..num_epochs.

    Relies on the module-level `clf`, `criterion`, `optimizer`,
    `train_dataloader`, and `start_epoch`. Prints a progress line every
    200 batches and the per-epoch average/median batch loss.

    num_epochs: total number of epochs (inclusive upper bound, default 10 —
    previously a hard-coded NUM_EPOCHS constant).
    """
    clf.train()
    for epoch in range(start_epoch, num_epochs + 1):
        losses = []
        for i, (X_batch, y_cls) in enumerate(train_dataloader):
            optimizer.zero_grad()
            y = y_cls.cuda()
            X_batch = X_batch.cuda()
            y_pred = clf(X_batch)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            train_loss = loss.item()
            if i % 200 == 0:
                print(
                    f'Finished epoch {epoch}/{num_epochs}, batch {i}. loss: {train_loss:.3f}.'
                )
            # Collect every batch's loss so the epoch summary reflects the
            # whole epoch, not just the printed batches.
            losses.append(train_loss)
        print(
            f'Finished epoch {epoch}. '
            f'avg loss: {np.mean(losses)}; median loss: {np.median(losses)}'
        )

train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment