Skip to content

Instantly share code, notes, and snippets.

@zygmuntz
Last active May 1, 2020 19:20
Show Gist options
  • Save zygmuntz/30e6a72e13ecf9b26fddf7cc10204847 to your computer and use it in GitHub Desktop.
Save zygmuntz/30e6a72e13ecf9b26fddf7cc10204847 to your computer and use it in GitHub Desktop.
A Siamese network example modified to use weighted L1 distance and cross-entropy loss.
#!/usr/bin/env python
"""
A Siamese network example modified to use weighted L1 distance and cross-entropy loss, as in
Siamese Neural Networks for One-shot Image Recognition
http://www.cs.toronto.edu/~rsalakhu/papers/oneshot1.pdf
"""
from __future__ import absolute_import
from __future__ import print_function

import random

import numpy as np
from keras import backend as K
from keras.datasets import mnist
from keras.layers import Dense, Dropout, Input, Lambda
from keras.models import Sequential, Model
from keras.optimizers import SGD, RMSprop
from sklearn.metrics import accuracy_score as accuracy
#
def get_abs_diff(vects):
    '''Return the element-wise absolute difference of a pair of tensors.

    `vects` is a two-element sequence (the outputs of the two Siamese
    branches); the result is |left - right|, computed with the Keras backend.
    '''
    left, right = vects
    difference = left - right
    return K.abs(difference)
def abs_diff_output_shape(shapes):
    '''Shape function for the absolute-difference Lambda layer.

    `shapes` must contain exactly two input shapes; the output keeps the
    shape of the first input unchanged.
    '''
    first_shape, _ = shapes
    return first_shape
def create_pairs(x, digit_indices, num_classes=10, rng=None):
    '''Positive and negative pair creation.

    Alternates between positive and negative pairs. With ``n`` equal to one
    less than the size of the smallest class, for every class ``d`` and every
    ``i`` in ``range(n)`` two pairs are emitted:

    * a positive pair ``[x[digit_indices[d][i]], x[digit_indices[d][i + 1]]]``
      (label 1), and
    * a negative pair ``[x[digit_indices[d][i]], x[digit_indices[dn][i]]]``
      (label 0), where ``dn`` is a randomly chosen class different from ``d``.

    Args:
        x: indexable collection of samples (e.g. a NumPy array).
        digit_indices: sequence where ``digit_indices[d]`` holds the indices
            into ``x`` of the samples of class ``d``.
        num_classes: number of classes to pair over (default 10). Must be at
            least 2 so that a different class exists for negative pairs.
        rng: object with a ``randrange`` method, e.g. ``random.Random(seed)``,
            used to pick the negative class. Defaults to the module-level
            ``random`` functions. ``np.random.seed`` does NOT affect these, so
            pass a seeded ``random.Random`` or call ``random.seed`` for
            reproducible pairs.

    Returns:
        ``(pairs, labels)`` as NumPy arrays. ``pairs[k]`` holds two samples.
        The labels alternate 1, 0, 1, 0, ... Both arrays are empty if some
        class has fewer than two samples.
    '''
    if rng is None:
        rng = random
    pairs = []
    labels = []
    # -1 so that index i + 1 (positive partner) stays in range for every class.
    n = min(len(digit_indices[d]) for d in range(num_classes)) - 1
    for d in range(num_classes):
        for i in range(n):
            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
            pairs.append([x[z1], x[z2]])
            # An offset in [1, num_classes - 1] guarantees dn != d.
            inc = rng.randrange(1, num_classes)
            dn = (d + inc) % num_classes
            z1, z2 = digit_indices[d][i], digit_indices[dn][i]
            pairs.append([x[z1], x[z2]])
            labels += [1, 0]
    return np.array(pairs), np.array(labels)
def create_base_network(input_dim):
    '''Build the feature-extraction tower shared by both Siamese branches.

    Three 128-unit ReLU Dense layers with Dropout(0.1) between consecutive
    Dense layers; the first layer expects flat vectors of length `input_dim`.
    '''
    layers = [
        Dense(128, input_shape=(input_dim,), activation='relu'),
        Dropout(0.1),
        Dense(128, activation='relu'),
        Dropout(0.1),
        Dense(128, activation='relu'),
    ]
    return Sequential(layers)
#
np.random.seed(1337)  # for reproducibility
# create_pairs() picks negative classes with Python's `random` module, whose
# state np.random.seed does not touch, so seed it as well.
random.seed(1337)
# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
input_dim = 784
nb_epoch = 20
# flatten each image into a vector of input_dim values and scale pixels by 1/255
X_train = X_train.reshape(-1, input_dim).astype('float32') / 255
X_test = X_test.reshape(-1, input_dim).astype('float32') / 255
# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(10)]
tr_pairs, tr_y = create_pairs(X_train, digit_indices)
digit_indices = [np.where(y_test == i)[0] for i in range(10)]
te_pairs, te_y = create_pairs(X_test, digit_indices)
# network definition
base_network = create_base_network(input_dim)
input_a = Input(shape=(input_dim,))
input_b = Input(shape=(input_dim,))
# because we re-use the same instance `base_network`,
# the weights of the network will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
# component-wise absolute difference |h_a - h_b| of the two embeddings
abs_diff = Lambda(get_abs_diff, output_shape=abs_diff_output_shape)([processed_a, processed_b])
# A single sigmoid unit learns the per-component weights alpha_i of the
# weighted L1 distance. NOTE: Dense includes a bias term by default, which the
# paper's formula sigmoid(sum_i alpha_i * |h_a,i - h_b,i|) does not have; pass
# use_bias=False to match the paper exactly.
flattened_weighted_distance = Dense(1, activation='sigmoid')(abs_diff)
# Keras 2+ names the functional Model arguments `inputs` / `outputs`
# (the Keras 1 `input=` / `output=` kwargs are rejected by current Keras).
model = Model(inputs=[input_a, input_b], outputs=flattened_weighted_distance)
# train
rms = RMSprop()
model.compile(loss='binary_crossentropy', optimizer=rms, metrics=['accuracy'])
tr_inputs = [tr_pairs[:, 0], tr_pairs[:, 1]]
te_inputs = [te_pairs[:, 0], te_pairs[:, 1]]
# Keras 2+ uses `epochs`; the Keras 1 `nb_epoch` kwarg is no longer accepted.
model.fit(tr_inputs, tr_y,
          validation_data=(te_inputs, te_y),
          batch_size=128, epochs=nb_epoch)
# compute final accuracy on training and test sets
# predict() returns an (N, 1) column (Dense(1) output); flatten it to match the
# 1-D label arrays, then round the probabilities to 0/1 class predictions.
tr_pred = model.predict(tr_inputs).ravel()
tr_acc = accuracy(tr_y, tr_pred.round())
te_pred = model.predict(te_inputs).ravel()
te_acc = accuracy(te_y, te_pred.round())
print('* Accuracy on the training set: {:.2%}'.format(tr_acc))
print('* Accuracy on the test set: {:.2%}'.format(te_acc))
@damondpham
Copy link

damondpham commented May 1, 2020

Thanks, this is really helpful!

I'm wondering if it is a problem that bias is included for the last layer? Your code says:

flattened_weighted_distance = Dense(1, activation = 'sigmoid')(abs_diff)

But the formula in the paper looks like:

sigmoid( sum(alpha_i * abs_diff_i) )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment