Skip to content

Instantly share code, notes, and snippets.

@Snailpong
Created April 11, 2022 12:04
Show Gist options
  • Save Snailpong/d7da64095f78fb8b69941767aaa7b2a3 to your computer and use it in GitHub Desktop.
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from datasets import DogCatDataset
from models import *
class Trainer:
    """Holds the dog-vs-cat data splits and runs the train/eval loop for a model.

    Builds a small deliberate subset of the dataset (100 train / 900 val;
    the remainder is discarded) and exposes `train_model` to fit any model
    against those fixed splits.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)
        dataset = DogCatDataset()
        # Compute the unused remainder instead of hard-coding 24000, so the
        # split no longer silently assumes len(dataset) == 25000.
        rest = len(dataset) - 1000
        self.train_dataset, self.val_dataset, _ = random_split(
            dataset, [100, 900, rest]
        )
        self.train_loader = DataLoader(self.train_dataset, batch_size=8, shuffle=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=64, shuffle=False)
        self.criterion = torch.nn.CrossEntropyLoss()

    def train_model(self, exp_name, model, optimizer):
        """Train `model` for 10 epochs with `optimizer`.

        Args:
            exp_name: label printed before training starts.
            model: a torch.nn.Module producing class logits.
            optimizer: optimizer already bound to `model`'s parameters.

        Returns:
            A list of 10 per-epoch records, each
            [train_loss, train_accuracy, val_loss, val_accuracy] as plain floats.
        """
        print("\n", exp_name)
        record_list = []
        model.to(self.device)
        for epoch in range(10):
            model.train()
            train_total_loss = 0.0
            train_correct = 0
            for images, labels in self.train_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                outputs = model(images)
                predicted = torch.max(outputs, 1)[1]
                loss = self.criterion(outputs, labels)
                train_total_loss += loss.item()
                # .item() keeps the running count a plain Python int rather than
                # a 0-d tensor, so the records convert cleanly via np.array later.
                train_correct += (labels == predicted).sum().item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            model.eval()
            val_total_loss = 0.0
            val_correct = 0
            with torch.no_grad():
                for images, labels in self.val_loader:
                    images = images.to(self.device)
                    labels = labels.to(self.device)
                    outputs = model(images)
                    predicted = torch.max(outputs, 1)[1]
                    loss = self.criterion(outputs, labels)
                    val_total_loss += loss.item()
                    val_correct += (labels == predicted).sum().item()
            train_loss = train_total_loss / len(self.train_loader)
            train_accuracy = train_correct / len(self.train_dataset)
            val_loss = val_total_loss / len(self.val_loader)
            val_accuracy = val_correct / len(self.val_dataset)
            record_list.append([train_loss, train_accuracy, val_loss, val_accuracy])
            print(
                f"epoch {epoch+1}:\t",
                f"train_loss: {train_loss:.4f},",
                f"train_accuracy: {train_accuracy:.4f},",
                f"val_loss: {val_loss:.4f},",
                f"val_accuracy: {val_accuracy:.4f}",
            )
        return record_list
def visualize(exp_names, results):
    """Plot each of the four recorded metrics in its own subplot.

    Args:
        exp_names: one label per experiment, used for the legend.
        results: per-experiment lists of per-epoch
            [train_loss, train_accuracy, val_loss, val_accuracy] records;
            stacked into an array of shape (experiment, epoch, metric).
    """
    # Renamed from `metric_name` — the original shadowed the list with the
    # loop variable of the same name, which only worked because enumerate()
    # binds its argument before the first reassignment.
    metric_names = ["train_loss", "train_accuracy", "val_loss", "val_accuracy"]
    results = np.array(results)
    for plot_num, metric_name in enumerate(metric_names):
        plt.subplot(2, 2, plot_num + 1)
        for exp_num, exp_name in enumerate(exp_names):
            # x axis is the 1-based epoch number (10 epochs per experiment).
            plt.plot(np.arange(10) + 1, results[exp_num, :, plot_num], label=exp_name)
        plt.legend()
        plt.xlabel("Epoch")
        plt.ylabel(metric_name)
    plt.show()
def train():
    """Run the three EfficientNet-B0 experiments and plot their learning curves.

    Trains a from-scratch model, a fully fine-tuned model, and a
    classifier-only fine-tuned model on the same Trainer splits, then
    hands all per-epoch records to `visualize`.
    """
    exp_names = ["scratch", "tune_all", "tune_classifier"]
    models = [
        efficientnet_b0_scratch(),
        efficientnet_b0_fine_tune_all(),
        efficientnet_b0_fine_tune_classifier(),
    ]
    # Each model gets its own Adam optimizer with the same learning rate.
    optimizers = [torch.optim.Adam(m.parameters(), lr=1e-3) for m in models]
    trainer = Trainer()
    records = [
        trainer.train_model(name, model, opt)
        for name, model, opt in zip(exp_names, models, optimizers)
    ]
    visualize(exp_names, records)
# Script entry point: run all three experiments only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment