Skip to content

Instantly share code, notes, and snippets.

@EndingCredits
Created May 11, 2018 12:40
Show Gist options
  • Save EndingCredits/305d6745d74cd2c753b5e36c8627ace3 to your computer and use it in GitHub Desktop.
Save EndingCredits/305d6745d74cd2c753b5e36c8627ace3 to your computer and use it in GitHub Desktop.
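Demonstration of backpropagation through tf.matrix_solve_ls: first-layer features are learned by gradient descent on random data, while the output layer is solved in closed form by least squares at every step.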
from __future__ import division
import numpy as np
import tensorflow as tf
# Launch the graph.
#
# Demonstration of backprop through tf.matrix_solve_ls: the first layer W1 is a
# trainable variable updated by Adam, while the output layer W2 is not a
# variable at all -- it is recomputed every step as the least-squares solution
# of h1 @ W2 ~= labels, and gradients flow through that solve back into W1.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    NUM_EXAMPLES = 8096
    EXAMPLE_DIM = 128
    DATA_FEATURES = 128
    STEPS = 10000
    KARPATHY_CONST = 0.00025
    # Log the loss roughly 100 times per run. max() guards against STEPS < 100,
    # where STEPS // 100 == 0 and `step % 0` would raise ZeroDivisionError.
    LOG_EVERY = max(1, STEPS // 100)

    # Generate some random data
    # data   = n x m   n = num examples (batch size), m = number of features
    # labels = n x c   n = num examples (batch size), c = number of outputs (1 here)
    data = tf.get_variable('x', [NUM_EXAMPLES, DATA_FEATURES], tf.float32,
                           tf.random_normal_initializer(stddev=1.0),
                           trainable=False)
    # Random +/-1 targets. randint yields an integer array, so cast explicitly
    # to match the float32 variable instead of relying on implicit conversion.
    label_values = (np.random.randint(2, size=[NUM_EXAMPLES, 1]) * 2 - 1)
    labels = tf.Variable(label_values.astype(np.float32), name='y',
                         dtype=tf.float32, trainable=False)

    # Weights for the first layer
    # weights = m x d   m = number of data features, d = dimension to solve
    W1 = tf.get_variable('m1', [DATA_FEATURES, EXAMPLE_DIM], tf.float32,
                         tf.random_normal_initializer(stddev=1.0))

    # Output of the first hidden layer, shape n x d
    h1 = tf.nn.relu(tf.matmul(data, W1))

    # Solve the least-squares problem h1 @ W2 ~= labels for W2 (d x 1). W2 is
    # an op, not a variable, so it is re-solved on every run; W1 is the only
    # trainable variable created here (data and labels are trainable=False).
    # NOTE(review): with the default l2_regularizer=0.0 this solve may fail or
    # become unstable if ReLU units die (all-zero columns in h1) -- consider a
    # small l2_regularizer if that happens.
    W2 = tf.matrix_solve_ls(h1, labels)

    # Model output according to the least-squares solution
    y = tf.matmul(h1, W2)

    # Squared-error loss, summed (not averaged) over all examples
    loss = tf.reduce_sum((y - labels) ** 2)

    # Optimiser
    optim = tf.train.AdamOptimizer(KARPATHY_CONST * 4).minimize(loss)

    # Initialize all tensorflow variables
    sess.run(tf.global_variables_initializer())

    # tqdm is optional (for tqdm-less plebs). Only the import is guarded: an
    # error or Ctrl-C during training must propagate, not be swallowed and
    # followed by a silent second training run without a progress bar.
    try:
        from tqdm import tqdm
    except ImportError:
        tqdm = None
    steps = tqdm(range(STEPS)) if tqdm is not None else range(STEPS)

    # Train W1 for STEPS full-batch steps, logging the loss periodically.
    for step in steps:
        _, loss_value = sess.run([optim, loss])
        if step % LOG_EVERY == 0:
            if tqdm is not None:
                # tqdm.write keeps the progress bar intact while printing.
                tqdm.write(str(loss_value))
            else:
                print(str(loss_value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment