from __future__ import division
from __future__ import print_function

import argparse

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.model_selection import train_test_split

from utils import get_minibatches_idx
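# Note: `utils` is not included with this gist. As an assumption about its
# behaviour (consistent with how it is used below), get_minibatches_idx is
# taken to return one array of example indices per minibatch. A hypothetical
# drop-in would be:
#
#   def get_minibatches_idx(n, batch_size, shuffle=False):
#       indices = np.arange(n)
#       if shuffle:
#           np.random.shuffle(indices)
#       return [indices[i:i + batch_size] for i in range(0, n, batch_size)]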
batch_size = 32
eps = 1e-8

np.random.seed(0)
sess = tf.Session()
class Model():
    def fit(self, X, y, max_epochs):
        for epoch in range(max_epochs):
            indices = get_minibatches_idx(len(X), batch_size, shuffle=True)
            for index in indices:
                batch_x = [X[i] for i in index]
                batch_y = [[y[i]] for i in index]
                feed_dict = {self.x: batch_x, self.y: batch_y}
                _, loss = sess.run([self.train_step, self.loss], feed_dict)
        return

    def predict(self, X):
        pred = sess.run(self.pred, {self.x: X})
        return pred
class MLP(Model):
    def __init__(self):
        self.x = tf.placeholder(tf.float32, [None, 1], 'x')
        self.y = tf.placeholder(tf.float32, [None, 1], 'y')
        self.pred = mlp(self.x)
        self.loss = tf.reduce_sum(tf.square(self.pred - self.y))  # L2 loss
        optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
        self.train_step = optimizer.minimize(self.loss)
def smooth_line(x, window_size):
    # Moving average over a window of 2*window_size points,
    # dropping the first and last window_size entries
    smoothed_x = []
    for i in range(window_size, len(x) - window_size):
        smoothed_x.append(np.mean(x[i - window_size:i + window_size]))
    return smoothed_x
def flatten_and_concat(x):
    # x is a list of tensors; flatten each to a column and stack them
    # into a single column vector
    output = [tf.reshape(i, [-1, 1]) for i in x]
    return tf.concat(output, axis=0)
def mlp(x):
    h = tf.contrib.layers.fully_connected(x, 500, tf.nn.relu)
    h = tf.contrib.layers.fully_connected(h, 2500, tf.nn.relu)
    h = tf.contrib.layers.fully_connected(h, 2500, tf.nn.relu)
    return tf.contrib.layers.fully_connected(h, 1, tf.identity)
class RegularizedMLP():
    def __init__(self):
        self.train_x = tf.placeholder(tf.float32, [None, 1], 'train_x')
        self.train_y = tf.placeholder(tf.float32, [None, 1], 'train_y')
        self.val_x = tf.placeholder(tf.float32, [None, 1], 'val_x')
        self.val_y = tf.placeholder(tf.float32, [None, 1], 'val_y')

        with tf.variable_scope('mlp'):
            self.train_pred = mlp(self.train_x)

        # Training loss, excluding the weight-decay term
        self.train_loss = tf.reduce_sum(tf.square(self.train_pred - self.train_y))
        train_grads = tf.gradients(self.train_loss, tf.trainable_variables())

        weight_decay = sum([tf.reduce_sum(tf.square(i)) for i in tf.trainable_variables()])

        # Second copy of the network, sharing weights, for the validation data
        with tf.variable_scope('mlp') as scope:
            scope.reuse_variables()
            val_pred = mlp(self.val_x)
        self.val_loss = tf.reduce_sum(tf.square(val_pred - self.val_y))
        val_grads = tf.gradients(self.val_loss, tf.trainable_variables())
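        # Derivation: the decay term is coef * sum(w^2), whose gradient is
        # 2 * coef * w. Fitting 2 * coef * w ~= grads_diff (the per-parameter
        # gap between the validation and training gradients) by ordinary
        # least squares over the flattened parameter vector w gives
        #     coef = <w, grads_diff> / (2 * <w, w>),
        # which is what the next few lines compute.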
        # Set the coefficient s.t. the gradient of the weight-decay term
        # is approximately equal to grads_diff
        self.grads_diff = flatten_and_concat([i - j for (i, j) in zip(val_grads, train_grads)])
        self.params = flatten_and_concat(tf.trainable_variables())

        # From the closed-form OLS solution
        wd_coef = tf.reduce_sum(self.params * self.grads_diff)
        wd_coef /= tf.reduce_sum(tf.square(self.params))
        self.wd_coef = wd_coef / 2

        self.train_loss_wd = self.train_loss + tf.maximum(0.0, self.wd_coef) * weight_decay
        optimizer = tf.train.GradientDescentOptimizer(0.001)
        self.train_step = optimizer.minimize(self.train_loss_wd)
    def fit(self, X, y, max_epochs):
        # Split into training and validation sets
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.33, random_state=42)

        weight_decay_values = []
        train_losses = []
        val_losses = []
        train_losses_wd = []

        for epoch in range(max_epochs):
            train_indices = get_minibatches_idx(len(X_train), batch_size, shuffle=True)
            val_indices = get_minibatches_idx(len(X_val), batch_size, shuffle=True)
            for it, iv in zip(train_indices, val_indices):
                batch_train_x = [X_train[i] for i in it]
                batch_train_y = [[y_train[i]] for i in it]
                batch_val_x = [X_val[i] for i in iv]
                batch_val_y = [[y_val[i]] for i in iv]

                feed_dict = {self.train_x: batch_train_x,
                             self.train_y: batch_train_y,
                             self.val_x: batch_val_x,
                             self.val_y: batch_val_y}
                _, loss_wd, train_loss, val_loss, wd = sess.run(
                    [self.train_step, self.train_loss_wd, self.train_loss,
                     self.val_loss, self.wd_coef], feed_dict)
                print(wd)

                train_losses.append(train_loss)
                val_losses.append(val_loss)
                train_losses_wd.append(loss_wd)
                weight_decay_values.append(wd)
        # Display metrics
        window_size = 10
        train_losses = smooth_line(train_losses, window_size)
        val_losses = smooth_line(val_losses, window_size)
        train_losses_wd = smooth_line(train_losses_wd, window_size)

        fig, ax1 = plt.subplots()
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_ylim((0, 10))
        ax1.plot(train_losses, color='r', label="Training loss")
        ax1.plot(val_losses, color='g', label="Validation loss")
        ax1.plot(train_losses_wd, color='k', label="Training loss with weight decay")
        plt.legend(loc="best")

        ax2 = ax1.twinx()
        ax2.set_ylabel('Weight decay coefficient')
        ax2.set_ylim((-1, 1))
        ax2.plot(weight_decay_values, color='b', label="Weight decay coefficient")
        plt.show()
        return
    def predict(self, X):
        # Reuse the training branch of the graph for inference
        feed_dict = {self.train_x: X}
        pred = sess.run(self.train_pred, feed_dict)
        return pred
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--reg', type=str, help='none or adaptive', required=True)
    args = parser.parse_args()

    max_epochs = 2000
    num_samples = 30
    noise_level = 0.1
    x_interval = [0, 1]

    # Generate noisy samples from a cosine curve
    true_fun = lambda X: np.cos(1.5 * np.pi * X)
    X_train = np.sort(np.random.uniform(x_interval[0], x_interval[1], num_samples))
    y_train = true_fun(X_train) + np.random.randn(num_samples) * noise_level
    X_test = np.linspace(x_interval[0], x_interval[1], 100)

    if args.reg == 'none':
        m = MLP()
    elif args.reg == 'adaptive':
        m = RegularizedMLP()
    else:
        raise ValueError('Invalid reg argument')

    sess.run(tf.global_variables_initializer())
    m.fit(X_train[:, np.newaxis], y_train, max_epochs)
    pred = m.predict(X_test[:, np.newaxis])

    plt.plot(X_test, pred, label="Model")
    plt.plot(X_test, true_fun(X_test), label="True function")
    plt.scatter(X_train, y_train, label="Samples")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.xlim(x_interval)
    plt.ylim((-2, 2))
    plt.legend(loc="best")
    plt.show()


if __name__ == "__main__":
    main()
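# Example usage (the filename is assumed; save this gist as, e.g.,
# adaptive_weight_decay.py):
#   python adaptive_weight_decay.py --reg none      # plain MLP, no regularization
#   python adaptive_weight_decay.py --reg adaptive  # MLP with adaptive weight decay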