Created
January 18, 2022 07:21
-
-
Save bakirillov/445f4a795a08c669d480d8ea8034728d to your computer and use it in GitHub Desktop.
A simple Siamese network made with PyTorch Lightning (with a data module class that performs Siamese arrangement of example pairs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import cv2 | |
import torch | |
import einops | |
import numpy as np | |
import pandas as pd | |
from torch import nn | |
import os.path as op | |
from torch.optim import Adam | |
import pytorch_lightning as pl | |
import torch.nn.functional as F | |
from torchvision import models, transforms | |
from torch.utils.data import DataLoader, Dataset | |
class DS(Dataset):
    """Row-wise dataset over a data frame.

    Each item is a ``(sample, label)`` pair read from two named columns of
    the frame. The sample column presumably holds file paths (the original
    code names it ``path``); when a transform is supplied it is called with
    that value and its result is returned in place of the raw value.
    """

    def __init__(self, df, fn_col, id_col, transform=None):
        """Store the frame, the two column names and the optional transform.

        Args:
            df: Frame indexed positionally via ``.iloc``.
            fn_col: Name of the column holding the sample value.
            id_col: Name of the column holding the label.
            transform: Optional callable applied to the sample value.
        """
        self.df = df
        self.fn_col = fn_col
        self.id_col = id_col
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the frame."""
        return len(self.df)

    def __getitem__(self, ind):
        """Return ``(sample, label)`` for the row at position ``ind``."""
        row = self.df.iloc[ind]
        sample, target = row[self.fn_col], row[self.id_col]
        if not self.transform:
            # No transform configured: hand back the raw column value.
            return sample, target
        return self.transform(sample), target
class SiameseArrangement(Dataset):
    """View a dataset of ``(sample, label)`` items as all ordered pairs.

    For a wrapped dataset of length ``L`` this exposes ``L**2`` items. Item
    ``ind`` pairs element ``ind // L`` (first) with element ``ind % L``
    (second) of the wrapped dataset.
    """

    def __init__(self, ds):
        """Wrap ``ds``, an indexable collection with ``len`` yielding
        ``(sample, label)`` tuples."""
        self.ds = ds

    def __len__(self):
        """Return the number of ordered pairs, ``len(ds) ** 2``."""
        return len(self.ds) ** 2

    def __getitem__(self, ind):
        """Return the pair addressed by the flat index ``ind``.

        Returns:
            ``(first_x, second_x, same, Y, X)`` where ``same`` is ``1`` when
            both labels compare equal and ``0`` otherwise, and ``Y``/``X``
            are the indices of the first/second element in the wrapped
            dataset.

        Raises:
            IndexError: If ``ind`` lies outside ``[-len(self), len(self))``.
        """
        size = len(self.ds)
        total = size * size
        # Explicit bounds check: with an empty wrapped dataset the division
        # below would otherwise raise ZeroDivisionError, which breaks plain
        # iteration (it relies on IndexError to stop).
        if not -total <= ind < total:
            raise IndexError(
                f"pair index {ind} out of range for {total} pairs"
            )
        Y, X = divmod(ind, size)
        first_x, first_y = self.ds[Y]
        second_x, second_y = self.ds[X]
        return first_x, second_x, int(first_y == second_y), Y, X
class DM(pl.LightningDataModule):
    """Lightning data module serving Siamese pairs for train/test/val frames.

    Each frame is wrapped in ``DS`` and then in ``SiameseArrangement``, so
    every loader yields batches of
    ``(first, second, same_label, first_index, second_index)``.
    """

    def __init__(
        self, train_df, test_df, val_df, transform, batch_size=64,
        fn_col=None, id_col=None
    ):
        """Store the frames and loader settings.

        Args:
            train_df: Frame with the training examples.
            test_df: Frame with the test examples.
            val_df: Frame with the validation examples.
            transform: Callable forwarded to ``DS`` as its ``transform``.
            batch_size: Batch size used by all three loaders.
            fn_col: Name of the sample column. When ``None``, the first
                column of each frame is used.
            id_col: Name of the label column. When ``None``, the second
                column of each frame is used.
        """
        super().__init__()
        self.train_df = train_df
        self.test_df = test_df
        self.val_df = val_df
        self.transform = transform
        self.batch_size = batch_size
        self.fn_col = fn_col
        self.id_col = id_col

    def prepare_data(self):
        """Nothing to download or preprocess."""

    def _arrange(self, df):
        """Build the pairwise dataset for one frame."""
        fn_col = df.columns[0] if self.fn_col is None else self.fn_col
        id_col = df.columns[1] if self.id_col is None else self.id_col
        # DS expects (df, fn_col, id_col, transform); the transform must be
        # passed by keyword so it is not mistaken for a column name.
        return SiameseArrangement(
            DS(df, fn_col, id_col, transform=self.transform)
        )

    def setup(self, stage=None):
        """Create the train, test and validation pair datasets."""
        self.train = self._arrange(self.train_df)
        self.test = self._arrange(self.test_df)
        self.val = self._arrange(self.val_df)

    def train_dataloader(self):
        """Return a shuffled loader over the training pairs."""
        self.train_loader = DataLoader(
            self.train, shuffle=True, batch_size=self.batch_size
        )
        return self.train_loader

    def test_dataloader(self):
        """Return an unshuffled loader over the test pairs."""
        self.test_loader = DataLoader(
            self.test, shuffle=False, batch_size=self.batch_size
        )
        return self.test_loader

    def val_dataloader(self):
        """Return an unshuffled loader over the validation pairs."""
        self.val_loader = DataLoader(
            self.val, shuffle=False, batch_size=self.batch_size
        )
        return self.val_loader
class Siamese(pl.LightningModule):
    """Siamese CNN trained with a contrastive loss on pairs of images.

    Same architecture as in
    https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch

    Both inputs go through the same convolutional stack and fully connected
    head, and the model outputs the Euclidean distance between the two
    5-dimensional embeddings. The convolutions preserve spatial size
    (reflection padding of 1 with 3x3 kernels) and the first linear layer
    expects ``8 * 100 * 100`` features, so inputs must be 3-channel images
    with ``H * W == 100 * 100`` (e.g. 100x100).
    """

    def __init__(self, margin=2.0):
        """Build the network.

        Args:
            margin: Margin of the contrastive loss used in the train and
                validation steps.
        """
        super().__init__()
        self.margin = margin
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),
            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(8 * 100 * 100, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 5),
        )

    def once(self, x):
        """Embed one batch of images into 5-dimensional vectors."""
        x = self.cnn1(x)
        # Flatten everything but the batch dimension for the linear head.
        x = self.fc1(x.reshape(x.shape[0], -1))
        return x

    def forward(self, first, second):
        """Return the Euclidean distance between the two embeddings.

        The result keeps a trailing singleton dimension (``keepdim=True``),
        i.e. it has shape ``(batch, 1)``.
        """
        first_z = self.once(first)
        second_z = self.once(second)
        euclidean_distance = F.pairwise_distance(
            first_z, second_z, keepdim=True
        )
        return euclidean_distance

    def configure_optimizers(self):
        """Optimise all parameters with Adam at its default settings."""
        optimizer = Adam(self.parameters())
        return optimizer

    def loss(self, ed, label, margin=2.0):
        """Contrastive loss averaged over the batch.

        Args:
            ed: Pairwise distances, shape ``(batch,)`` or ``(batch, 1)``.
            label: ``0`` for pairs whose distance should shrink, ``1`` for
                pairs that should end up at least ``margin`` apart. Shape
                ``(batch,)`` or ``(batch, 1)``.
            margin: Distance beyond which ``label == 1`` pairs stop
                contributing to the loss.

        Returns:
            Scalar tensor with the mean loss.
        """
        # Flatten both operands: multiplying a (batch,) label with a
        # (batch, 1) distance would broadcast to a (batch, batch) matrix and
        # average every label against every distance.
        ed = ed.reshape(-1)
        label = torch.as_tensor(
            label, dtype=ed.dtype, device=ed.device
        ).reshape(-1)
        pull = (1 - label) * ed ** 2
        push = label * torch.clamp(margin - ed, min=0.0) ** 2
        return torch.mean(pull + push)

    def _shared_step(self, batch, log_name):
        """Compute and log the contrastive loss for one batch."""
        first, second, same, _, _ = batch
        ed = self(first, second)
        # SiameseArrangement emits 1 for pairs with EQUAL labels, whereas
        # ``loss`` pushes label-1 pairs apart; flip it so that matching
        # pairs are pulled together.
        loss = self.loss(ed, 1 - same, self.margin)
        self.log(log_name, loss)
        return loss

    def training_step(self, train_batch, batch_idx):
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self._shared_step(train_batch, "train_loss")

    def validation_step(self, val_batch, batch_idx):
        """Return the validation loss for one batch (logged as ``val_loss``)."""
        return self._shared_step(val_batch, "val_loss")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment