Skip to content

Instantly share code, notes, and snippets.

@Anjum48
Last active January 16, 2019 06:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Anjum48/0ad193d4f408346c47533b835e86e10c to your computer and use it in GitHub Desktop.
Save Anjum48/0ad193d4f408346c47533b835e86e10c to your computer and use it in GitHub Desktop.
PyTorch crashes
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import sys
sys.path.append("/home/anjum/PycharmProjects/kaggle")
from sklearn.metrics import pairwise_distances, pairwise_distances_argmin
from humpback_whales.losses import ContrastiveLoss
from humpback_whales.datasets import SiameseDataset, BasicImageIterator
from datetime import datetime
from time import time
# Where the competition data is read from and where every artefact is written.
INPUT_DIR = "/mnt/storage/kaggle_data/humpback_whales/"
OUTPUT_DIR = "/mnt/storage/kaggle_output/humpback_whales/"

# Run identifier, prefixed to the names of the files written to OUTPUT_DIR.
TIMESTAMP = datetime.strftime(datetime.now(), "%Y%m%d-%H%M%S")

# Options shared by the data loaders / datasets.
N_WORKERS = 4     # DataLoader worker processes
GREYSCALE = True  # forwarded to the dataset classes as `greyscale=`

# Seed PyTorch's global random number generator.
torch.manual_seed(42)
class SiameseNetwork(nn.Module):
    """Twin-branch embedding network built on one shared ResNet-50 backbone.

    Both images of a pair are passed through the *same* module (``self.net``),
    so the two branches share weights. The backbone's final ``fc`` layer is
    replaced by a linear projection to ``embedding_size`` outputs.
    """

    def __init__(self, embedding_size=128):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # Replace the classifier head with a linear embedding layer of the requested width.
        backbone.fc = nn.Linear(backbone.fc.in_features, embedding_size)
        self.net = backbone

    def get_embedding(self, x):
        """Return the backbone's output for a single batch ``x``."""
        return self.net(x)

    def forward(self, image1, image2):
        """Embed both halves of a pair with the shared backbone; returns ``(out1, out2)``."""
        return self.get_embedding(image1), self.get_embedding(image2)
def _build_loaders():
    """Create the training (image-pair) and validation (single-image) DataLoaders.

    Training images get a light random affine augmentation. Validation images
    are only converted (no random transform), so the validation accuracy is not
    perturbed by augmentation noise from one epoch to the next.
    """
    image_dir = os.path.join(INPUT_DIR, "train_cropped")
    train_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomAffine(degrees=10, translate=(0.0, 0.0),
                                scale=(0.95, 1.05), shear=5),
        transforms.ToTensor(),
    ])
    valid_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
    ])
    train_dataset = SiameseDataset(pd.read_csv(os.path.join(INPUT_DIR, "train_multiple.csv")),
                                   image_dir,
                                   transform=train_transforms, greyscale=GREYSCALE)
    valid_dataset = BasicImageIterator(pd.read_csv(os.path.join(INPUT_DIR, "valid_multiple.csv")),
                                       image_dir,
                                       transform=valid_transforms, greyscale=GREYSCALE)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=N_WORKERS)
    valid_loader = DataLoader(valid_dataset, batch_size=32, num_workers=N_WORKERS)
    return train_loader, valid_loader


def _nearest_neighbour_accuracy(train_embeddings, train_labels, valid_embeddings, valid_labels):
    """Return the fraction of validation items whose nearest training embedding has their label.

    Each argument is a list of per-batch chunks, which are concatenated first.
    Embedding chunks must not require grad (they are converted to NumPy).
    """
    gallery = torch.cat(train_embeddings).cpu().numpy()
    queries = torch.cat(valid_embeddings).cpu().numpy()
    gallery_labels = np.concatenate(train_labels)
    query_labels = np.concatenate(valid_labels)
    # Index of the closest gallery embedding for every query (sklearn default metric).
    nearest = pairwise_distances_argmin(queries, gallery)
    predicted = gallery_labels[nearest].flatten()
    return (predicted == query_labels).sum() / len(query_labels)


def _save_training_plot(train_counter, train_loss, valid_counter, valid_acc):
    """Plot validation accuracy and training loss against iteration; save as a PNG."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    fig, ax = plt.subplots()
    ax.plot(valid_counter, valid_acc, label="validation accuracy")
    ax.plot(train_counter, train_loss, label="training loss")
    ax.set_xlabel("iteration")
    ax.legend()
    if valid_acc:
        ax.set_title("Final accuracy: %.5f" % np.mean(valid_acc[-5:]))
    fig.savefig(os.path.join(OUTPUT_DIR, TIMESTAMP + "_training.png"))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def train(net, optimizer, criterion, device, epochs=20, validate=True):
    """Train a Siamese network on image pairs, tracking nearest-neighbour validation accuracy.

    Args:
        net: model called as ``net(image_a, image_b) -> (emb_a, emb_b)`` that also
            exposes ``get_embedding(images)``.
        optimizer: optimizer over ``net``'s parameters.
        criterion: loss called as ``criterion(emb_a, emb_b, target)``.
        device: device the batches are moved to.
        epochs: number of passes over the training pairs.
        validate: when True (default), after every epoch each validation image is
            labelled with the label of its nearest training embedding from that epoch,
            and the resulting accuracy is printed and recorded.

    Side effects: prints progress and saves ``<TIMESTAMP>_training.png`` in OUTPUT_DIR.
    """
    train_loader, valid_loader = _build_loaders()

    start = time()
    iteration_number = 0
    train_loss, train_counter = [], []
    valid_acc, valid_counter = [], []

    for epoch in range(epochs):
        # Main training loop
        net.train()
        train_embeddings, train_labels = [], []
        for i, batch in enumerate(train_loader):
            image_a = batch["image_a"].to(device)
            image_b = batch["image_b"].to(device)
            label = batch["target"].to(device)

            optimizer.zero_grad()
            embedding_a, embedding_b = net(image_a, image_b)
            loss = criterion(embedding_a, embedding_b, label)
            loss.backward()
            optimizer.step()
            iteration_number += 1

            if validate:
                # Detach before storing: these embeddings only serve as the
                # nearest-neighbour gallery. Left attached they keep autograd history
                # alive for the whole epoch, and NumPy/sklearn refuse to convert
                # tensors that require grad.
                # NOTE(review): gallery embeddings come from augmented images and
                # from weights that change during the epoch.
                train_embeddings.extend((embedding_a.detach(), embedding_b.detach()))
                train_labels.extend((batch["label_a"], batch["label_b"]))

            if i % 100 == 0:
                loss_value = loss.item()
                print("[%2d, %3d] Loss: %.4f" % (epoch, i, loss_value))
                train_counter.append(iteration_number)
                train_loss.append(loss_value)

        if not validate:
            continue

        # Validation
        net.eval()
        valid_embeddings, valid_labels = [], []
        with torch.no_grad():
            for batch in valid_loader:
                image = batch["image"].to(device)
                valid_embeddings.append(net.get_embedding(image))
                valid_labels.append(batch["label"])
        validation_accuracy = _nearest_neighbour_accuracy(
            train_embeddings, train_labels, valid_embeddings, valid_labels)
        print("Validation accuracy: %.5f" % validation_accuracy)
        valid_acc.append(validation_accuracy)
        valid_counter.append(iteration_number)

    print("Trained in %.2f" % ((time() - start) / 60))
    _save_training_plot(train_counter, train_loss, valid_counter, valid_acc)
def main():
    """Train a ResNet-50 Siamese network with contrastive loss and save its weights.

    Runs on the first CUDA device when one is available, otherwise on the CPU.
    The final ``state_dict`` is written to OUTPUT_DIR.
    """
    # Create the output directory before training: the checkpoint is saved there
    # only after all epochs finish, so a missing directory would otherwise throw
    # away the whole run.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = SiameseNetwork().to(device)
    criterion = ContrastiveLoss(margin=1)
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    train(net, optimizer, criterion, device, epochs=200)
    torch.save(net.state_dict(), os.path.join(OUTPUT_DIR, TIMESTAMP + "_siamese_resnet50.pt"))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment