# Source: GitHub Gist nlgranger/076ad1f7ce3c412a7983b9d1c02bc1b5
# (last active August 23, 2018 17:37).
# Implementation of Matching Networks for TensorFlow.
# O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra, and others,
# “Matching networks for one shot learning,” in Advances in Neural Information
# Processing Systems, 2016, pp. 3630–3638.
# Copyright 2018 Nicolas Granger
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
def cdist(a, b, metric='euclidean', p=2.):
    """Computes the pairwise distances between the rows of two matrices.

    :param a:
        a matrix of shape [n, feat_size]
    :param b:
        a matrix of shape [m, feat_size]
    :param metric:
        one of 'cityblock' (sum of absolute differences), 'euclidean',
        'minkowski' (p-norm of the difference, also accepted under the legacy
        spelling 'minkovski') or 'cosine' (one minus the cosine similarity)
    :param p:
        order of the norm, only used by the 'minkowski' metric, must be >= 1
    :return:
        a matrix of shape [n, m] whose entry (i, j) is the distance between
        row i of `a` and row j of `b`
    :raises ValueError:
        if `metric` is unknown, or if `p` < 1 with the 'minkowski' metric
    """
    if metric == 'cosine':
        b_t = tf.transpose(b)                                       # [feat_size, m]
        a_norm = tf.norm(a, 'euclidean', axis=1, keepdims=True)     # [n, 1]
        b_norm = tf.norm(b_t, 'euclidean', axis=0, keepdims=True)   # [1, m]
        # the epsilon avoids a division by zero for null vectors
        return 1 - tf.matmul(a, b_t) / (a_norm * b_norm + 1e-7)

    if metric not in ('cityblock', 'euclidean', 'minkowski', 'minkovski'):
        raise ValueError("invalid metric")
    if metric in ('minkowski', 'minkovski') and p < 1.:
        # p < 1 does not define a norm (triangle inequality fails)
        raise ValueError("too small p for p-norm")

    # broadcast [n, 1, feat_size] - [1, m, feat_size] -> [n, m, feat_size]
    diff = tf.expand_dims(a, 1) - tf.expand_dims(b, 0)

    if metric == 'cityblock':
        return tf.reduce_sum(tf.abs(diff), axis=2)
    if metric == 'euclidean':
        return tf.norm(diff, 'euclidean', axis=2)
    return tf.norm(diff, p, axis=2)
def shepards(inputs, support, ep_voca_size, shots):
    """Computes log-predictions using Shepard's method [Thrun98]_.

    Each input is compared to every support vector with the cosine similarity.
    The similarities are normalized with a softmax over the whole support set,
    and the resulting probabilities are summed per class.

    :param inputs:
        input vectors, a matrix of shape [batch_size, feat_size]
    :param support:
        support vectors, a matrix of shape [ep_voca_size * shots, feat_size],
        their labels are assumed to be 0, 1, ..., ep_voca_size - 1 repeated
        `shots` times
    :param ep_voca_size:
        number of distinct classes
    :param shots:
        number of labeled samples for each class
    :return:
        a matrix of shape [batch_size, ep_voca_size] containing the
        log-probabilities of each class for each input

    .. [Thrun98] S. Thrun, “Lifelong Learning Algorithms,” in Learning to Learn,
       S. Thrun and L. Pratt, Eds. Boston, MA: Springer US, 1998, pp. 181–209.
    """
    with tf.variable_scope("shepards"):
        # cosine similarity between every input and every support vector
        similarity = 1 - cdist(inputs, support, metric='cosine')
        # log-probability of each support sample, normalized over the support
        log_attention = tf.nn.log_softmax(similarity, axis=1)
        # support rows are ordered [shot 0: classes 0..C-1, shot 1: ...], so a
        # row-major reshape groups them as [batch, shots, classes]; -1 keeps the
        # batch dimension dynamic (inputs.shape[0] may be unknown in the graph)
        log_attention = tf.reshape(log_attention, (-1, shots, ep_voca_size))
        # log of the summed probabilities over the shots of each class
        return tf.reduce_logsumexp(log_attention, axis=1)
class MatchingNet:
    """Implementation of Matching Networks [Vinyals2016]_.

    .. [Vinyals2016] O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra, and
       others, “Matching networks for one shot learning,” in Advances in Neural
       Information Processing Systems, 2016, pp. 3630–3638.

    .. note::
        The LSTM implementation might differ from the paper (initialization,
        peepholes, ...).

    :param feat_size:
        dimension of the feature vectors, also the number of units of every LSTM
    :param f_steps:
        number of attention steps performed by :meth:`f`, must be at least 1
    :param name:
        name of the variable scope, defaults to the class name
    """

    def __init__(self, feat_size, f_steps, name=None):
        # validate before touching the graph so that a bad argument does not
        # leave a half-built variable scope behind
        if f_steps < 1:
            raise ValueError("the number of steps must be at least 1")

        with tf.variable_scope(name or self.__class__.__name__) as scope:
            self.scope = scope
            self.feat_size = feat_size

            # bidirectional LSTM used by the support set embedding g
            with tf.variable_scope("g") as scope_g:
                self.scope_g = scope_g
                self.g_lstm_fw = tf.nn.rnn_cell.LSTMCell(
                    feat_size, name="g_lstm_fw")
                self.g_lstm_bw = tf.nn.rnn_cell.LSTMCell(
                    feat_size, name="g_lstm_bw")

            # attention LSTM used by the input embedding f
            with tf.variable_scope("f") as scope_f:
                self.scope_f = scope_f
                self.steps = f_steps
                self.f_lstm = tf.nn.rnn_cell.LSTMCell(feat_size, name="f_lstm")

    def variables(self):
        """Returns the variables of the three LSTM cells (g forward, g
        backward, f), in that order.

        NOTE(review): TF cells create their variables when first called, so
        this list is presumably empty until the model has been called once.
        """
        return (self.g_lstm_fw.variables
                + self.g_lstm_bw.variables
                + self.f_lstm.variables)

    def __call__(self, *args, **kwargs):
        """Wraps :meth:`call` with the model's variable and name scopes."""
        with tf.variable_scope(self.scope, auxiliary_name_scope=False) as scope:
            with tf.name_scope(scope.original_name_scope):
                return self.call(*args, **kwargs)

    def g(self, inputs):
        """Embeds the support set with a bidirectional LSTM.

        The support samples are read as one time-major sequence (batch of a
        single sequence). Each embedding is the sum of the forward and
        backward hidden states plus a skip connection from the input.

        :param inputs:
            support samples, a matrix of shape [n_support, feat_size]
        :return:
            the support embeddings, a matrix of shape [n_support, feat_size]
        """
        with tf.variable_scope(self.scope_g, auxiliary_name_scope=False) as scope_g:
            with tf.name_scope(scope_g.original_name_scope):
                # [n_support, 1, feat_size]: the support samples become the
                # time steps of a single sequence
                sequence = tf.expand_dims(inputs, 1)
                (h_fw, h_bw), _ = tf.nn.bidirectional_dynamic_rnn(
                    self.g_lstm_fw, self.g_lstm_bw, sequence,
                    initial_state_fw=self.g_lstm_fw.zero_state(1, tf.float32),
                    initial_state_bw=self.g_lstm_bw.zero_state(1, tf.float32),
                    time_major=True)
                return h_fw[:, 0, :] + h_bw[:, 0, :] + inputs

    def f(self, inputs, support_embeddings):
        """Embeds the inputs with an attention LSTM over the support embeddings.

        :param inputs:
            input samples, a matrix of shape [batch_size, feat_size]
        :param support_embeddings:
            support embeddings (as returned by :meth:`g`), a matrix of shape
            [n_support, feat_size]
        :return:
            the input embeddings, a matrix of shape [batch_size, feat_size]
        """
        with tf.variable_scope(self.scope_f, auxiliary_name_scope=False) as scope_f:
            with tf.name_scope(scope_f.original_name_scope):
                # zero_state accepts a dynamic batch size, unlike tf.zeros on
                # the static inputs.shape[0] which fails when it is unknown
                c, h = self.f_lstm.zero_state(tf.shape(inputs)[0], tf.float32)
                for _ in range(self.steps):
                    # softmax attention of each input over the support set
                    a = tf.nn.softmax(
                        tf.matmul(h + inputs, tf.transpose(support_embeddings)))
                    r = tf.matmul(a, support_embeddings)
                    # the paper is unclear here, possibly a typo in the formula
                    _, (c, h) = self.f_lstm(r, (c, h + inputs))
                return h + inputs

    def call(self, inputs, support, ep_voca_size, shots, return_embeddings=False):
        r"""Returns log-predictions for the input samples given support training shots.

        :param inputs:
            A matrix containing row vector samples (:math:`f'(\widehat{x})`).
        :param support:
            A matrix containing row vector support samples (:math:`g'(x_i)`),
            the labels are assumed to cycle from 0 to `ep_voca_size` - 1,
            `shots` times.
        :param ep_voca_size:
            Vocabulary size or number of distinct labels.
        :param shots:
            Number of training shots for each label class.
        :param return_embeddings:
            Also return the conditional embeddings for the inputs and the
            support.
        :return:
            A matrix of row vectors containing the log-probabilities of each
            input; if `return_embeddings` is set, also returns the fully
            conditional embeddings of the inputs and the support.
        """
        support_embeddings = self.g(support)
        inputs_embeddings = self.f(inputs, support_embeddings)
        log_predictions = shepards(inputs_embeddings, support_embeddings,
                                   ep_voca_size, shots)
        if return_embeddings:
            return log_predictions, inputs_embeddings, support_embeddings
        return log_predictions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment