Last active
May 10, 2017 17:51
-
-
Save ahwillia/4fd6c28826473d8b9af82a9cd43821d8 to your computer and use it in GitHub Desktop.
Randomized Matrix Factorization in TensorFlow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf # works on version 1.0.0 | |
import numpy as np | |
from tqdm import trange | |
# Synthetic target: X is the product of two Gaussian factor matrices
# (100 x 3 and 3 x 100), so it is exactly rank 3 by construction.
A = np.random.randn(100, 3).astype(np.float32)
B = np.random.randn(3, 100).astype(np.float32)
X = A.dot(B)

# Trainable estimates of the two factors, initialised from a standard normal
# and shaped like the true factors. X itself enters the graph as a constant.
Ahat = tf.Variable(tf.random_normal(A.shape))
Bhat = tf.Variable(tf.random_normal(B.shape))
_X = tf.constant(X)
# helper function for indexing a matrix along an arbitrary axis (as of
# TensorFlow 1.0.0, tf.gather only indexes along the first axis)
def gather_along_axis(data, indices, axis=0):
    """Gather slices of ``data`` at ``indices`` along the given axis.

    ``tf.gather`` is applied along the first axis only, so for any other
    axis the target axis is swapped with axis 0, the gather is performed,
    and the swap is undone.

    Args
    ----
    data : Tensor, object to gather indices from
    indices : 1D Tensor of ints, indices into data tensor
    axis : int, axis of Tensor to gather along. Negative values count from
        the last axis. A falsy value (0 or None) gathers along axis 0.

    Returns
    -------
    sampled_data : Tensor, subsample of data tensor; same layout as ``data``
        except that ``axis`` now has one entry per element of ``indices``

    Raises
    ------
    ValueError : if the static rank of ``data`` is unknown, or ``axis`` is
        outside ``[-rank, rank)``
    """
    if not axis:
        # Axis 0 is what tf.gather does natively; no transposes needed.
        return tf.gather(data, indices)

    rank = data.get_shape().ndims
    if rank is None:
        raise ValueError(
            "gather_along_axis needs a statically known rank when axis != 0"
        )
    if not -rank <= axis < rank:
        raise ValueError(
            "axis %d is out of range for a tensor of rank %d" % (axis, rank)
        )
    if axis < 0:
        # Normalise so the permutation below is built from a valid index.
        axis += rank
    if axis == 0:
        return tf.gather(data, indices)

    # Permutation that swaps axis 0 with `axis` and leaves the rest in place.
    # A swap is its own inverse, so the same perm restores the original order.
    perm = list(range(rank))
    perm[0], perm[axis] = axis, 0
    swapped = tf.transpose(data, perm)
    sampled_data = tf.transpose(tf.gather(swapped, indices), perm)
    return sampled_data
# On each iteration randomly sample 50 rows and 50 columns of X (with
# replacement). The index bounds come from X itself rather than being
# repeated as literals, so the sampler always covers the whole matrix and
# never indexes out of range if the data shape is changed above.
n_rows, n_cols = X.shape
rows = tf.random_uniform((50,), maxval=n_rows, dtype=tf.int32)
cols = tf.random_uniform((50,), maxval=n_cols, dtype=tf.int32)

# Sampled versions of X, Ahat, Bhat: the same row indices select from X and
# Ahat, and the same column indices select from X and Bhat, so the sampled
# block of X lines up with the product of the sampled factors.
X_smp = gather_along_axis(tf.gather(_X, rows), cols, axis=1)
A_smp = tf.gather(Ahat, rows)
B_smp = gather_along_axis(Bhat, cols, axis=1)

# Mean squared residual between the sampled block and its reconstruction.
resid = X_smp - tf.matmul(A_smp, B_smp)
loss = tf.reduce_mean(tf.square(resid))
# Optimiser setup: Adam with a fixed learning rate.
lr = 1e-2
optimizer = tf.train.AdamOptimizer(lr)
train_step = optimizer.minimize(loss)

# Train for 1000 iterations, recording the loss fetched alongside each
# update step so the learning curve can be plotted afterwards.
f_hist = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _step in trange(1000):
        step_loss, _ = sess.run([loss, train_step])
        f_hist.append(step_loss)
# Plot the recorded per-iteration loss as individual points.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(f_hist, '.')
plt.show()
Author
ahwillia
commented
May 10, 2017
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment