Skip to content

Instantly share code, notes, and snippets.

@joshua-feldman
Created December 9, 2019 10:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save joshua-feldman/9e6b3ec5ce7ed68e4e68f40de90c911f to your computer and use it in GitHub Desktop.
Save joshua-feldman/9e6b3ec5ce7ed68e4e68f40de90c911f to your computer and use it in GitHub Desktop.
Matrix factorisation demo using sparklyr
##############################
# 1. SETUP #
##############################
library(tidyverse)
library(sparklyr)
library(formattable)
# Seed R's random number generator.
set.seed(1234)

# Tile colours shared by the formattable() tables: low values red, high green.
customRed <- "#ff7f7f"
customGreen <- "#71CA97"

# Toy explicit-feedback data: six users, six items, five ratings per user.
# Users 1-3 give items 1-3 a 5 and items 4-6 a 1; users 4-6 do the reverse.
# Each user leaves exactly one item unrated (noted per row below).
user <- rep(c(1, 2, 3, 4, 5, 6), each = 5)
item <- unlist(list(
  c(1, 3, 4, 5, 6),  # user 1 skips item 2
  c(1, 2, 4, 5, 6),  # user 2 skips item 3
  c(2, 3, 4, 5, 6),  # user 3 skips item 1
  c(1, 2, 3, 4, 6),  # user 4 skips item 5
  c(1, 2, 3, 4, 5),  # user 5 skips item 6
  c(1, 2, 3, 5, 6)   # user 6 skips item 4
))
rating <- unlist(rep(list(c(5, 5, 1, 1, 1), c(1, 1, 1, 5, 5)), each = 3))
df <- data.frame(user = user, item = item, rating = rating)
# Open a local Spark session, upload the ratings table to it, and fit a
# rank-2 ALS model. The formula names the rating column (left-hand side)
# and the user and item id columns (right-hand side).
sc <- spark_connect(master = "local")
df_tbl <- sc %>%
  copy_to(df, overwrite = TRUE)
model <- df_tbl %>%
  ml_als(formula = rating ~ user + item, rank = 2)
##############################
# 2. GRAPHICS #
##############################
# 2.1. USER-ITEM MATRIX
# Observed ratings as a wide table: one row per user (first column, shown
# in bold) and one column per film. Pairs a user did not rate become NA.
#
# Ids are mapped to names by indexing into a label vector, so id k maps to
# the k-th label. The previous str_replace_all() chain rewrote digits inside
# the id text and would corrupt multi-digit ids (e.g. 10 -> "Airplane0").
user_labels <- c("Amy", "Ben", "Chloe", "Daniel", "Emily", "Fred")
item_labels <- c("Airplane", "Bridesmaids", "Superbad",
                 "Halloween", "Psycho", "Scream")
df_table <- df %>%
  mutate(
    user = user_labels[user],
    # Factor levels fix the film column order (names_sort = TRUE below).
    item = factor(item_labels[item], levels = item_labels)
  ) %>%
  pivot_wider(names_from = item, values_from = rating, names_sort = TRUE) %>%
  # Keep a plain data.frame, as the previous spread() on a data.frame did.
  as.data.frame()
# Blank header for the user-name column.
colnames(df_table)[1] <- " "
formattable(
  df_table,
  c(
    list(" " = formatter("span", style = ~ style(font.weight = "bold"))),
    setNames(
      rep(list(color_tile(customRed, customGreen)), length(item_labels)),
      item_labels
    )
  )
)
# 2.2. ITEM MATRIX
# Item latent factors learned by ALS, shown transposed: one row per latent
# factor and one column per film.
#
# Spark does not guarantee the row order of the collected factor table.
# Rows are therefore sorted by item id before being positionally matched to
# the film columns of df_table, which are ordered by item id 1..6.
item_matrix <- model$model$item_factors %>%
  as.data.frame() %>%
  arrange(id) %>%
  select(starts_with("features_")) %>%
  mutate(across(everything(), ~ round(as.numeric(.x), 2))) %>%
  t() %>%
  as.data.frame()
# One "Latent Factor #k" row per model rank dimension.
rownames(item_matrix) <- paste0("Latent Factor #", seq_len(nrow(item_matrix)))
colnames(item_matrix) <- colnames(df_table)[-1]
item_matrix <- item_matrix %>%
  tibble::rownames_to_column(var = " ")
formattable(
  item_matrix,
  c(
    list(" " = formatter("span", style = ~ style(font.weight = "bold"))),
    setNames(
      rep(list(color_tile(customRed, customGreen)), ncol(item_matrix) - 1),
      colnames(item_matrix)[-1]
    )
  )
)
# 2.3. USER MATRIX
# User latent factors learned by ALS: one row per user (named via the first
# column of df_table) and one column per latent factor.
#
# Rows are sorted by user id so that they line up positionally with the user
# names in df_table. Spark does not guarantee the collected row order.
user_matrix <- model$model$user_factors %>%
  as.data.frame() %>%
  arrange(id) %>%
  select(starts_with("features_")) %>%
  mutate(across(everything(), ~ round(as.numeric(.x), 2)))
colnames(user_matrix) <- paste0("Latent Factor #", seq_along(user_matrix))
rownames(user_matrix) <- df_table$` `
user_matrix <- user_matrix %>%
  tibble::rownames_to_column(var = " ")
formattable(
  user_matrix,
  c(
    list(" " = formatter("span", style = ~ style(font.weight = "bold"))),
    setNames(
      rep(list(color_tile(customRed, customGreen)), ncol(user_matrix) - 1),
      colnames(user_matrix)[-1]
    )
  )
)
# 2.4. ESTIMATED USER-ITEM MATRIX
# Reconstructed ratings: the user factor matrix (users x rank) multiplied by
# the transposed item factor matrix (rank x items), rounded to 2 decimals.
# The original code rounded to whole numbers because of a misplaced
# parenthesis.
#
# Both factor tables are sorted by id so rows and columns line up with the
# user names and film columns of df_table. Distinct names are used so the
# script's `user` / `item` data vectors are not overwritten.
user_factor_mat <- model$model$user_factors %>%
  as.data.frame() %>%
  arrange(id) %>%
  select(starts_with("features_")) %>%
  as.matrix()
item_factor_mat <- model$model$item_factors %>%
  as.data.frame() %>%
  arrange(id) %>%
  select(starts_with("features_")) %>%
  as.matrix()
# tcrossprod(U, V) is U %*% t(V): one row per user, one column per item.
estimated <- tcrossprod(user_factor_mat, item_factor_mat) %>%
  round(2) %>%
  as.data.frame()
rownames(estimated) <- df_table$` `
colnames(estimated) <- colnames(df_table)[-1]
estimated <- estimated %>%
  tibble::rownames_to_column(var = " ")
formattable(
  estimated,
  c(
    list(" " = formatter("span", style = ~ style(font.weight = "bold"))),
    setNames(
      rep(list(color_tile(customRed, customGreen)), ncol(estimated) - 1),
      colnames(estimated)[-1]
    )
  )
)
##############################
# 3. DISCONNECT #
##############################
# Close only the Spark connection this script opened (sc). Any other Spark
# connections in the R session are left alone.
spark_disconnect(sc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment