Skip to content

Instantly share code, notes, and snippets.

@Jean-Reinhold
Created September 15, 2022 19:34
Show Gist options
  • Save Jean-Reinhold/180c8f526c9a3cda251e3216b03a1b7e to your computer and use it in GitHub Desktop.
Save Jean-Reinhold/180c8f526c9a3cda251e3216b03a1b7e to your computer and use it in GitHub Desktop.
Siamese Neural Net with contrastive loss
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
class Siamese(nn.Module):
    """Shared-weight MLP encoder for Siamese (pairwise) comparison.

    Both members of a pair are passed through the *same* ``liner`` stack,
    so the two "towers" share weights by construction.

    Args:
        shape: Number of input features per sample, i.e. the size of the
            last input dimension consumed by the first ``nn.Linear``.
        hidden: Width of the first hidden layer (default 128).
        embedding: Size of the produced embedding (default 32).
        dropout: Dropout probability applied after each linear layer
            (default 0.5).
    """

    def __init__(
        self,
        shape: int,
        hidden: int = 128,
        embedding: int = 32,
        dropout: float = 0.5,
    ):
        super().__init__()
        # The attribute name ``liner`` (sic) and the layer order are kept so
        # existing state_dict keys (``liner.0.weight``, ``liner.3.weight``...)
        # continue to load.
        self.liner = nn.Sequential(
            nn.Linear(shape, hidden),
            nn.Dropout(p=dropout),
            nn.ReLU(),
            nn.Linear(hidden, embedding),
            nn.Dropout(p=dropout),
            # Trailing ReLU: embeddings are element-wise non-negative.
            nn.ReLU(),
        )
        # NOTE(review): ``out`` is never used by forward()/forward_one(); it is
        # kept only so checkpoints containing ``out.*`` still load with
        # strict=True. Remove once no saved model depends on it.
        self.out = nn.Linear(embedding, 1)

    def forward_one(self, x: torch.Tensor) -> torch.Tensor:
        """Embed one batch ``x`` of shape ``(..., shape)`` into ``(..., embedding)``."""
        return self.liner(x)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        """Embed both members of a pair with the shared encoder.

        Args:
            x1: First batch of inputs, last dimension ``shape``.
            x2: Second batch of inputs, last dimension ``shape``.

        Returns:
            Tuple ``(out1, out2)`` with the embeddings of ``x1`` and ``x2``.
        """
        out1 = self.forward_one(x1)
        out2 = self.forward_one(x2)
        return out1, out2
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    For each pair with Euclidean distance ``d`` and label ``y``::

        loss = (1 - y) * d**2 + y * max(margin - d, 0)**2

    so ``y = 0`` pulls a pair together and ``y = 1`` pushes it apart until
    it is at least ``margin`` away. Per-pair losses are averaged.

    Args:
        margin: Distance beyond which ``y = 1`` pairs incur no loss.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Compute the mean contrastive loss for a batch of pairs.

        Args:
            output1: Embeddings of the first members, e.g. shape (N, D).
            output2: Embeddings of the second members, same shape as output1.
            label: Per-pair labels (0 = pull together, 1 = push apart), of
                shape (N,) or (N, 1); bool, integer or float.

        Returns:
            Scalar tensor with the mean loss.
        """
        # keepdim=True -> one distance per pair, shape (N, 1) for (N, D) inputs.
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
        # Bool labels cannot be used in ``1 - label``; cast to the distance's
        # dtype/device so any label type works.
        label = torch.as_tensor(
            label, dtype=euclidean_distance.dtype, device=euclidean_distance.device
        )
        # A label of shape (N,) would broadcast against the (N, 1) distance
        # into an (N, N) matrix and silently average over all i/j combinations.
        # Align it with the distance when it has exactly one entry per pair.
        if label.numel() == euclidean_distance.numel():
            label = label.reshape(euclidean_distance.shape)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + label
            * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )
        return loss_contrastive
if __name__ == "__main__":
    # Print the default architecture summary only when run as a script, not
    # on every import (it also allocates a 1024-input model).
    print(Siamese(shape=1024))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment