Last active
January 16, 2019 06:02
-
-
Save Anjum48/0ad193d4f408346c47533b835e86e10c to your computer and use it in GitHub Desktop.
PyTorch crashes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import os | |
import matplotlib.pyplot as plt | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import models, transforms | |
import sys | |
sys.path.append("/home/anjum/PycharmProjects/kaggle") | |
from sklearn.metrics import pairwise_distances, pairwise_distances_argmin | |
from humpback_whales.losses import ContrastiveLoss | |
from humpback_whales.datasets import SiameseDataset, BasicImageIterator | |
from datetime import datetime | |
from time import time | |
# Data/output locations and run-wide configuration.
INPUT_DIR = "/mnt/storage/kaggle_data/humpback_whales/"
OUTPUT_DIR = "/mnt/storage/kaggle_output/humpback_whales/"
# Timestamp used to tag every artifact (plots, weights) produced by this run.
TIMESTAMP = datetime.now().strftime("%Y%m%d-%H%M%S")
N_WORKERS = 4       # DataLoader worker processes
GREYSCALE = True    # load images as single-channel
torch.manual_seed(42)  # reproducible init/shuffling
class SiameseNetwork(nn.Module):
    """Twin-branch embedding network built on a single shared ResNet-50.

    The backbone's final fully-connected layer is replaced so that each
    image is mapped to an ``embedding_size``-dimensional vector; both
    inputs of ``forward`` run through the same weights.
    """

    def __init__(self, embedding_size=128):
        super(SiameseNetwork, self).__init__()
        backbone = models.resnet50(pretrained=True)
        backbone.fc = nn.Linear(backbone.fc.in_features, embedding_size)
        self.net = backbone

    def forward(self, image1, image2):
        """Embed both images with the shared backbone; returns a pair."""
        return self.get_embedding(image1), self.get_embedding(image2)

    def get_embedding(self, x):
        """Embed a single batch of images."""
        return self.net(x)
def train(net, optimizer, criterion, device, epochs=20, validate=True):
    """Train a siamese network with a contrastive objective.

    Args:
        net: model returning a pair of embeddings from ``net(a, b)`` and a
            single embedding from ``net.get_embedding(x)``.
        optimizer: torch optimizer over ``net``'s parameters.
        criterion: pairwise loss taking (embedding_a, embedding_b, target).
        device: torch.device the model lives on.
        epochs: number of passes over the training set.
        validate: when True, run nearest-neighbour validation after each
            epoch (this flag was previously accepted but ignored).

    Side effects: prints progress, saves a training-curve figure to
    OUTPUT_DIR.
    """
    train_transforms = transforms.Compose([transforms.ToPILImage(),
                                           transforms.RandomAffine(degrees=10, translate=(0.0, 0.0),
                                                                   scale=(0.95, 1.05), shear=5),
                                           # transforms.Resize((224, 224)),
                                           transforms.ToTensor()])
    # BUG FIX: validation previously reused the augmented training
    # transforms; evaluation should be deterministic, so no RandomAffine.
    valid_transforms = transforms.Compose([transforms.ToPILImage(),
                                           transforms.ToTensor()])

    train_dataset = SiameseDataset(pd.read_csv(os.path.join(INPUT_DIR, "train_multiple.csv")),
                                   os.path.join(INPUT_DIR, "train_cropped"),
                                   transform=train_transforms, greyscale=GREYSCALE)
    valid_dataset = BasicImageIterator(pd.read_csv(os.path.join(INPUT_DIR, "valid_multiple.csv")),
                                       os.path.join(INPUT_DIR, "train_cropped"),
                                       transform=valid_transforms, greyscale=GREYSCALE)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=N_WORKERS)
    valid_loader = DataLoader(valid_dataset, batch_size=32, num_workers=N_WORKERS)

    start = time()
    iteration_number = 0
    train_loss, train_counter = [], []
    valid_acc, valid_counter = [], []

    for epoch in range(epochs):
        # Main training loop
        net.train()
        train_embeddings, train_labels = [], []
        for i, batch in enumerate(train_loader):
            image_a, image_b, label = batch["image_a"], batch["image_b"], batch["target"]
            image_a, image_b, label = image_a.to(device), image_b.to(device), label.to(device)

            optimizer.zero_grad()
            embedding_a, embedding_b = net(image_a, image_b)
            loss = criterion(embedding_a, embedding_b, label)
            loss.backward()
            optimizer.step()
            iteration_number += 1

            # BUG FIX: the embeddings were stored with their autograd graphs
            # attached, which (a) kept every graph of the epoch alive (GPU
            # memory blow-up) and (b) made pairwise_distances_argmin crash
            # later when converting a grad-requiring tensor to numpy.
            # Detach and move to CPU immediately.
            train_embeddings.append(embedding_a.detach().cpu())
            train_embeddings.append(embedding_b.detach().cpu())
            train_labels.append(batch["label_a"])
            train_labels.append(batch["label_b"])

            if i % 100 == 0:
                print("[%2d, %3d] Loss: %.4f" % (epoch, i, loss))
                train_counter.append(iteration_number)
                train_loss.append(loss.item())

        if not validate:
            continue  # honour the previously-ignored `validate` flag

        # Validation: label each validation image with the label of its
        # nearest training embedding and measure top-1 accuracy.
        net.eval()
        with torch.no_grad():
            valid_embeddings, valid_labels = [], []
            for batch in valid_loader:
                image = batch["image"].to(device)
                valid_embeddings.append(net.get_embedding(image).cpu())
                valid_labels.append(batch["label"])

        train_embeddings = torch.cat(train_embeddings)
        valid_embeddings = torch.cat(valid_embeddings)
        train_labels = np.concatenate(train_labels)
        valid_labels = np.concatenate(valid_labels)

        valid_pred_index = pairwise_distances_argmin(valid_embeddings, train_embeddings)
        valid_pred_labels = train_labels[valid_pred_index].flatten()
        validation_accuracy = (valid_pred_labels == valid_labels).sum() / len(valid_labels)
        print("Validation accuracy: %.5f" % validation_accuracy)
        valid_acc.append(validation_accuracy)
        valid_counter.append(iteration_number)

    # Elapsed time is reported in minutes (division by 60 below).
    print("Trained in %.2f minutes" % ((time() - start) / 60))

    plt.plot(valid_counter, valid_acc, label="validation accuracy")
    plt.plot(train_counter, train_loss, label="training loss")
    plt.legend()
    plt.title("Final accuracy: %.5f" % np.mean(valid_acc[-5:]))
    plt.savefig(os.path.join(OUTPUT_DIR, TIMESTAMP + "_training.png"))
def main():
    """Build model, loss and optimizer, train, then save the final weights."""
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_gpu else "cpu")

    net = SiameseNetwork().to(device)
    criterion = ContrastiveLoss(margin=1)
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    train(net, optimizer, criterion, device, epochs=200)

    weights_path = os.path.join(OUTPUT_DIR, TIMESTAMP + "_siamese_resnet50.pt")
    torch.save(net.state_dict(), weights_path)
# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment