Skip to content

Instantly share code, notes, and snippets.

@bdutta19
Last active March 27, 2018 15:25
Show Gist options
  • Save bdutta19/77d94c1236922a00755a578c4b1df489 to your computer and use it in GitHub Desktop.
Save bdutta19/77d94c1236922a00755a578c4b1df489 to your computer and use it in GitHub Desktop.
proto-net-omniglot.py
from __future__ import print_function
from PIL import Image
import numpy as np
import tensorflow as tf
import os
import glob
#import matplotlib.pyplot as plt
def conv_block(inputs, out_channels, name='conv'):
    """Apply one conv -> batch-norm -> ReLU -> max-pool stage.

    Args:
        inputs: 4-D image tensor passed to ``tf.contrib.layers.conv2d``
            (presumably NHWC, that layer's default data format).
        out_channels: number of 3x3 convolution filters.
        name: variable scope that holds this stage's weights.

    Returns:
        The pooled activation tensor.
    """
    with tf.variable_scope(name):
        # activation_fn=None: tf.contrib.layers.conv2d applies ReLU by
        # default, which would insert a ReLU *before* batch norm and then a
        # second one after it. The non-linearity belongs after the norm.
        conv = tf.contrib.layers.conv2d(
            inputs, out_channels, kernel_size=3, padding='SAME',
            activation_fn=None)
        # updates_collections=None runs the moving-mean/variance updates in
        # place with the forward pass, so the train op needs no explicit
        # UPDATE_OPS control dependency.
        conv = tf.contrib.layers.batch_norm(
            conv, updates_collections=None, decay=0.99, scale=True,
            center=True)
        conv = tf.nn.relu(conv)
        conv = tf.contrib.layers.max_pool2d(conv, 2)
    return conv
def encoder(x, h_dim, z_dim, reuse=False):
    """Embed a batch of images with four stacked conv blocks.

    The first three blocks produce ``h_dim`` channels and the last one
    produces ``z_dim``. The final feature map is flattened to one vector per
    example. All weights live under the ``encoder`` variable scope, so a
    second call with ``reuse=True`` shares them with the first call.
    """
    stages = (
        ('conv_1', h_dim),
        ('conv_2', h_dim),
        ('conv_3', h_dim),
        ('conv_4', z_dim),
    )
    with tf.variable_scope('encoder', reuse=reuse):
        features = x
        for scope_name, width in stages:
            features = conv_block(features, width, name=scope_name)
        return tf.contrib.layers.flatten(features)
def euclidean_distance(a, b):
    """Return pairwise mean squared differences between rows of a and b.

    Args:
        a: tensor of shape N x D.
        b: tensor of shape M x D.

    Returns:
        An N x M tensor whose entry [i, j] is
        mean_k (a[i, k] - b[j, k]) ** 2, i.e. the squared Euclidean distance
        divided by D. Despite the function name, no square root is taken.

    NOTE(review): the Prototypical Networks formulation uses the
    un-normalized squared distance (a sum over D). Dividing by D only
    rescales the logits that are built from this value; confirm that the
    mean is intended.
    """
    # Broadcasting (N, 1, D) against (1, M, D) yields the same (N, M, D)
    # difference tensor that explicitly tiling both inputs did, without
    # materializing the two tiled copies.
    diff = tf.expand_dims(a, axis=1) - tf.expand_dims(b, axis=0)
    return tf.reduce_mean(tf.square(diff), axis=2)
# Training schedule: outer epochs, and episodes sampled per epoch.
n_epochs, n_episodes = 20, 100
# Episode composition: classes per episode, then support and query images
# drawn per class. n_shot + n_query must not exceed n_examples.
n_way, n_shot, n_query = 10, 15, 5
# Number of images stored per class in the loaded dataset.
n_examples = 20
# Input image geometry.
im_width, im_height, channels = 28, 28, 1
# Encoder widths: hidden conv blocks, and the final embedding block.
h_dim, z_dim = 64, 64
root_dir = 'data/omniglot'
# Load every training class listed in the split file into a single array of
# shape (n_classes, n_examples, im_height, im_width).
train_split_path = os.path.join(root_dir, 'splits', 'train.txt')
with open(train_split_path, 'r') as train_split:
    # Skip blank lines (e.g. a trailing newline) so they cannot break the
    # alphabet/character/rotation unpacking below.
    train_classes = [line.rstrip() for line in train_split if line.strip()]
n_classes = len(train_classes)
train_dataset = np.zeros([n_classes, n_examples, im_height, im_width], dtype=np.float32)
for i, tc in enumerate(train_classes):
    # Entries have three '/'-separated parts. The last part is a 3-character
    # prefix (presumably "rot") followed by the rotation angle.
    alphabet, character, rotation = tc.split('/')
    rotation = float(rotation[3:])
    im_dir = os.path.join(root_dir, 'data', alphabet, character)
    im_files = sorted(glob.glob(os.path.join(im_dir, '*.png')))
    # A mismatch would otherwise raise an opaque IndexError (too many images)
    # or silently leave all-zero images in the dataset (too few).
    if len(im_files) != n_examples:
        raise ValueError('expected {} images in {}, found {}'.format(
            n_examples, im_dir, len(im_files)))
    for j, im_file in enumerate(im_files):
        # The context manager closes the file handle that Image.open keeps
        # open for lazy loading.
        with Image.open(im_file) as img:
            pixels = img.rotate(rotation).resize((im_width, im_height))
            # np.asarray instead of np.array(..., copy=False): the float
            # conversion always copies, and NumPy 2 raises on copy=False
            # when a copy is required. Values are inverted as 1 - pixel.
            train_dataset[i, j] = 1. - np.asarray(pixels, dtype=np.float32)
print(train_dataset.shape)
# Episode inputs. Support images x and query images q are both laid out as
# (classes, images per class, im_height, im_width, channels).
x = tf.placeholder(tf.float32, [None, None, im_height, im_width, channels])
q = tf.placeholder(tf.float32, [None, None, im_height, im_width, channels])
x_shape = tf.shape(x)
q_shape = tf.shape(q)
num_classes, num_support = x_shape[0], x_shape[1]
num_queries = q_shape[1]
# y holds the true class index of each query image, in the same
# (classes, queries per class) layout as q.
y = tf.placeholder(tf.int64, [None, None])
y_one_hot = tf.one_hot(y, depth=num_classes)
# Embed all support images as one flat batch. Then average each class's
# embeddings into one prototype per class: (num_classes, emb_dim).
emb_x = encoder(tf.reshape(x, [num_classes * num_support, im_height, im_width, channels]), h_dim, z_dim)
emb_dim = tf.shape(emb_x)[-1]
emb_x = tf.reduce_mean(tf.reshape(emb_x, [num_classes, num_support, emb_dim]), axis=1)
# Embed the queries with the same (reused) encoder weights. The reshape
# assumes q has the same number of classes as x.
emb_q = encoder(tf.reshape(q, [num_classes * num_queries, im_height, im_width, channels]), h_dim, z_dim, reuse=True)
# Distance from every query embedding to every class prototype.
dists = euclidean_distance(emb_q, emb_x)
# Negated distances serve as logits. Log-probabilities are reshaped back to
# (num_classes, num_queries, number of prototypes).
log_p_y = tf.reshape(tf.nn.log_softmax(-dists), [num_classes, num_queries, -1])
# Mean negative log-probability of the true class over all queries.
ce_loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, log_p_y), axis=-1), [-1]))
# Fraction of queries whose most probable (i.e. nearest) prototype is the
# true class.
acc = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(log_p_y, axis=-1), y)))
train_op = tf.train.AdamOptimizer().minimize(ce_loss)
# Fail fast on episode settings the sampling below cannot satisfy. Otherwise
# they surface as broadcast errors or as silently all-zero classes.
if n_shot + n_query > n_examples:
    raise ValueError('n_shot + n_query ({}) exceeds n_examples ({})'.format(
        n_shot + n_query, n_examples))
if n_way > n_classes:
    raise ValueError('n_way ({}) exceeds the number of training classes ({})'.format(
        n_way, n_classes))

sess = tf.InteractiveSession()
init_op = tf.global_variables_initializer()
sess.run(init_op)

# Query labels are identical in every episode: row c is all c. They use
# int64 to match the y placeholder; the previous uint8 cast silently wraps
# once n_way exceeds 255.
labels = np.tile(np.arange(n_way)[:, np.newaxis], (1, n_query)).astype(np.int64)
log_every = 50  # report loss/accuracy every this many episodes

for ep in range(n_epochs):
    for epi in range(n_episodes):
        # Sample n_way distinct classes for this episode.
        epi_classes = np.random.permutation(n_classes)[:n_way]
        support = np.zeros([n_way, n_shot, im_height, im_width], dtype=np.float32)
        query = np.zeros([n_way, n_query, im_height, im_width], dtype=np.float32)
        for i, epi_cls in enumerate(epi_classes):
            # Disjoint random split of this class's images into support
            # and query sets.
            selected = np.random.permutation(n_examples)[:n_shot + n_query]
            support[i] = train_dataset[epi_cls, selected[:n_shot]]
            query[i] = train_dataset[epi_cls, selected[n_shot:]]
        # Append the trailing channel axis expected by the placeholders.
        support = np.expand_dims(support, axis=-1)
        query = np.expand_dims(query, axis=-1)
        _, ls, ac = sess.run([train_op, ce_loss, acc],
                             feed_dict={x: support, q: query, y: labels})
        if (epi + 1) % log_every == 0:
            print('[epoch {}/{}, episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(ep+1, n_epochs, epi+1, n_episodes, ls, ac))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment