SVD Recommendations Using TensorFlow
import numpy as np
import tensorflow as tf

# Recommend products to users with a truncated SVD of a (synthetic)
# user-item rating matrix R:
#     R ~= U_k S_k V_k^T
# split symmetrically into user factors Su = U_k sqrt(S_k) and item factors
# Si = sqrt(S_k) V_k^T, so predicted ratings are Su @ Si.
# The graph is built with the tf.compat.v1 / tf.linalg API so the script runs
# on late TensorFlow 1.x releases as well as on TensorFlow 2.x.

# Set random seed for reproducibility
np.random.seed(1000)

nb_users = 5000
nb_products = 2000
nb_factors = 500
max_rating = 5
nb_rated_products = 500
top_k_products = 10
# When True, products a user has already rated are never suggested back to
# them (otherwise the reconstruction of known ratings dominates the top-k).
exclude_rated_products = True

# A rank-k truncation can keep at most min(users, products) singular values;
# slicing past that would silently use fewer factors than requested.
if not 0 < nb_factors <= min(nb_users, nb_products):
    raise ValueError(
        "nb_factors must be in [1, {}], got {}".format(
            min(nb_users, nb_products), nb_factors))

# Create a random User-Item matrix: each user rates exactly
# nb_rated_products *distinct* products (sampling without replacement, so an
# index is never drawn twice and overwritten) with an integer in
# [1, max_rating]. Unrated entries stay 0.
uim = np.zeros((nb_users, nb_products), dtype=np.float32)
for i in range(nb_users):
    rated = np.random.choice(nb_products, size=nb_rated_products, replace=False)
    uim[i, rated] = np.random.randint(1, max_rating + 1, size=nb_rated_products)

# Create a TensorFlow graph
graph = tf.Graph()

with graph.as_default():
    # User-item matrix, fed at run time
    user_item_matrix = tf.compat.v1.placeholder(
        tf.float32, shape=(nb_users, nb_products))

    # SVD: tf.linalg.svd returns (s, u, v) with A = u @ diag(s) @ v^T.
    # NOTE: v is NOT transposed -- its *columns* are the right singular
    # vectors. Singular values are returned in descending order.
    s, u, v = tf.linalg.svd(user_item_matrix)

    # Keep only the nb_factors largest singular values / vectors.
    sqrt_sk = tf.linalg.diag(tf.sqrt(s[:nb_factors]))   # (k, k)
    uk = u[:, :nb_factors]                              # (users, k)
    vk = v[:, :nb_factors]                              # (products, k)

    # User factors (users x k) and item factors (k x products)
    su = tf.matmul(uk, sqrt_sk)
    si = tf.matmul(sqrt_sk, vk, transpose_b=True)

    # Predicted ratings: rank-k reconstruction of the user-item matrix
    ratings_t = tf.matmul(su, si)

    if exclude_rated_products:
        # Push already-rated products to -inf so top_k only picks unseen ones.
        ratings_t = tf.where(
            tf.equal(user_item_matrix, 0.0),
            ratings_t,
            tf.fill(tf.shape(ratings_t), float("-inf")))

    # Pick the top k suggestions for every user (row-wise)
    _, best_items_t = tf.nn.top_k(ratings_t, k=top_k_products)

# Compute the top k suggestions for all users; the context manager closes the
# session and releases its resources once the result has been fetched.
with tf.compat.v1.Session(graph=graph) as session:
    best_items = session.run(best_items_t, feed_dict={user_item_matrix: uim})

# Suggestions for users 1000-1009
for i in range(1000, 1010):
    print('User {}: {}'.format(i, best_items[i]))
User 1000: [ 412 867 1040 509 1311 1562 758 1796 636 556]
User 1001: [ 548 88 1299 175 81 1837 282 1555 1796 1902]
User 1002: [ 433 667 460 821 1762 775 1673 278 284 1540]
User 1003: [1823 602 1874 43 1979 1612 1755 857 891 1701]
User 1004: [1700 1312 892 621 194 1919 196 1746 1697 1192]
User 1005: [ 891 221 1112 1387 768 1697 916 485 1673 1515]
User 1006: [ 463 611 1986 1253 175 1362 1112 1811 1045 768]
User 1007: [1170 70 1886 757 412 606 892 1772 1540 1415]
User 1008: [ 757 855 509 329 410 1304 1900 1631 476 284]
User 1009: [ 888 1654 6 1453 735 1745 505 422 1878 1965]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This comment has been minimized.
Hello Giuseppe Bonaccorso,
I would like to know what the dataset is (i.e., the structure of the data).
I'm a newbie and I want to use real data with it.
Thank you.