Skip to content

Instantly share code, notes, and snippets.

@Aankhen
Created September 1, 2018 17:25
Show Gist options
  • Save Aankhen/cb6c5545823d36cacc548af9ddaad527 to your computer and use it in GitHub Desktop.
Save Aankhen/cb6c5545823d36cacc548af9ddaad527 to your computer and use it in GitHub Desktop.
‘Recommending books (with Rust)’, enhanced
cargo-features = ["edition"]
[package]
name = "goodbooks-recommender"
version = "0.1.0"
authors = ["A"]
edition = "2018"
[dependencies]
reqwest = "0.8.8"
failure = "0.1.2"
serde_derive = "1.0.74"
serde = "1.0.74"
serde_json = "1.0.26"
csv = "1.0.1"
sbr = "0.4.0"
rand = "0.5.5"
elapsed = "0.1.2"
clap = "^2.32.0"
#![feature(uniform_paths)]
use std::collections::HashMap;
use std::fs::File;
use std::io::BufWriter;
use std::path::Path;
use serde_derive::{Deserialize, Serialize};
/// Download file from `url` and save it to `destination`.
fn download(url: impl AsRef<str>, destination: impl AsRef<Path>) -> Result<(), failure::Error> {
let destination = destination.as_ref();
if destination.exists() {
return Ok(());
}
let file = File::create(destination)?;
let mut writer = BufWriter::new(file);
let mut response = reqwest::get(url.as_ref())?;
response.copy_to(&mut writer)?;
Ok(())
}
/// Download ratings and metadata.
fn download_data(
ratings_path: impl AsRef<Path>,
books_path: impl AsRef<Path>,
ratings_url: impl AsRef<str>,
books_url: impl AsRef<str>,
) {
download(ratings_url.as_ref(), ratings_path.as_ref()).expect("Could not download ratings");
download(books_url.as_ref(), books_path.as_ref()).expect("Could not download metadata");
}
/// A single (user, book) interaction read from the ratings CSV.
#[derive(Serialize, Deserialize, Debug)]
struct WishlistEntry {
    /// Id of the user who produced the interaction.
    user_id: usize,
    /// Id of the book the user interacted with.
    book_id: usize,
}
/// Read every record of the ratings CSV at `path`.
///
/// # Errors
///
/// Fails if the file cannot be opened, or if any row cannot be deserialized
/// into a [`WishlistEntry`]. The first bad row aborts the whole read.
fn deserialize_ratings(path: impl AsRef<Path>) -> Result<Vec<WishlistEntry>, failure::Error> {
    let mut csv_reader = csv::Reader::from_path(path)?;
    let mut rows = Vec::new();
    for row in csv_reader.deserialize::<WishlistEntry>() {
        rows.push(row?);
    }
    Ok(rows)
}
/// The subset of a book-metadata CSV row this program uses.
#[derive(Serialize, Deserialize, Debug)]
struct Book {
    /// Numeric id of the book; matches `WishlistEntry::book_id`.
    book_id: usize,
    /// Human-readable title, used as the lookup key for predictions.
    title: String,
}
/// Deserialize from file at `path` into book mappings.
fn deserialize_books(
path: impl AsRef<Path>,
) -> Result<(HashMap<usize, String>, HashMap<String, usize>), failure::Error> {
let mut reader = csv::Reader::from_path(path.as_ref())?;
let entries: Vec<Book> = reader
.deserialize::<Book>()
.collect::<Result<Vec<_>, _>>()?;
let id_to_title: HashMap<usize, String> = entries
.iter()
.map(|book| (book.book_id, book.title.clone()))
.collect();
let title_to_id: HashMap<String, usize> = entries
.iter()
.map(|book| (book.title.clone(), book.book_id))
.collect();
Ok((id_to_title, title_to_id))
}
use sbr::models::ewma::{Hyperparameters, ImplicitEWMAModel};
use sbr::models::{Loss, Optimizer};
/// Construct an untrained implicit-feedback EWMA model over `num_items` items.
///
/// The hyperparameters are fixed:
/// - 32-dimensional embeddings;
/// - learning rate 0.16 and L2 penalty 0.0004;
/// - WARP loss with the Adagrad optimizer;
/// - 10 epochs on one thread.
fn build_model(num_items: usize) -> ImplicitEWMAModel {
    // NOTE(review): the second argument of `Hyperparameters::new` (128) is
    // presumably the maximum sequence length — confirm against the sbr docs.
    Hyperparameters::new(num_items, 128)
        .embedding_dim(32)
        .learning_rate(0.16)
        .l2_penalty(0.0004)
        .loss(Loss::WARP)
        .optimizer(Optimizer::Adagrad)
        .num_epochs(10)
        .num_threads(1)
        .build()
}
use sbr::data::{Interaction, Interactions};
/// Convert raw rating rows into an sbr [`Interactions`] set.
///
/// The user and item dimensions are one past the largest id seen, so ids can
/// be used directly as indices. Each row's position in `data` is passed as the
/// third argument of `Interaction::new`. That argument is presumably the
/// timestamp (confirm against the sbr docs), so input order is preserved.
///
/// Empty `data` yields an `Interactions` with zero users and zero items
/// instead of panicking.
fn build_interactions(data: &[WishlistEntry]) -> Interactions {
    // Compute both dimensions in a single pass. Starting at 0 makes empty
    // input well defined, where `max().unwrap()` would panic.
    let (num_users, num_items) = data.iter().fold((0, 0), |(users, items), entry| {
        (users.max(entry.user_id + 1), items.max(entry.book_id + 1))
    });
    let mut interactions = Interactions::new(num_users, num_items);
    for (idx, datum) in data.iter().enumerate() {
        interactions.push(Interaction::new(datum.user_id, datum.book_id, idx));
    }
    interactions
}
use rand::SeedableRng;
use sbr::data::user_based_split;
use sbr::OnlineRankingModel;
use sbr::evaluation::mrr_score;
/// Train `model` on a user-based split of `data` and evaluate it.
///
/// The split uses an RNG with a fixed seed, so repeated runs see the same
/// train/test partition. The 0.2 passed to `user_based_split` is presumably
/// the test fraction (confirm against the sbr docs).
///
/// Returns the mean reciprocal rank (MRR) on the held-out part.
///
/// # Errors
///
/// Propagates any error from fitting or from scoring.
fn fit(model: &mut ImplicitEWMAModel, data: &Interactions) -> Result<f32, failure::Error> {
    let seed = [42u8; 16];
    let mut rng = rand::XorShiftRng::from_seed(seed);
    let (train, test) = user_based_split(data, &mut rng, 0.2);

    let train_compressed = train.to_compressed();
    model.fit(&train_compressed)?;

    let test_compressed = test.to_compressed();
    Ok(mrr_score(model, &test_compressed)?)
}
fn serialize_model(
model: &ImplicitEWMAModel,
path: impl AsRef<Path>,
) -> Result<(), failure::Error> {
let file = File::create(path.as_ref())?;
let mut writer = BufWriter::new(file);
Ok(serde_json::to_writer(&mut writer, model)?)
}
use elapsed::measure_time;
/// Download the training data, fit a model on it and save the model.
///
/// This powers the `fit` subcommand. If a file already exists at
/// `model_path`, nothing is done beyond reporting that.
///
/// # Panics
///
/// Panics on any download, parsing, fitting or serialization failure.
fn main_build(
    model_path: impl AsRef<Path>,
    ratings_path: impl AsRef<Path>,
    books_path: impl AsRef<Path>,
    ratings_url: impl AsRef<str>,
    books_url: impl AsRef<str>,
) {
    let destination = model_path.as_ref();
    if destination.exists() {
        println!("Model already fitted.");
        return;
    }
    let (ratings_file, books_file) = (ratings_path.as_ref(), books_path.as_ref());

    println!("Downloading data...");
    download_data(ratings_file, books_file, ratings_url, books_url);

    let ratings = deserialize_ratings(ratings_file).unwrap();
    // Only the number of books is needed here, for the progress message.
    let book_count = deserialize_books(books_file).unwrap().0.len();
    println!(
        "Deserialized {} ratings and {} books.",
        ratings.len(),
        book_count
    );

    let interactions = build_interactions(&ratings);
    let mut model = build_model(interactions.num_items());

    println!("Fitting...");
    let (duration, mrr) =
        measure_time(|| fit(&mut model, &interactions).expect("Unable to fit model"));
    println!("Fitted model with MRR of {:.2} in {}.", mrr, duration);

    serialize_model(&model, destination).expect("Unable to serialize model.");
}
use std::io::BufReader;
/// Load a JSON-encoded model from `model_path`.
///
/// # Errors
///
/// Fails if the file cannot be opened or its contents are not a valid
/// serialized `ImplicitEWMAModel`.
fn deserialize_model(model_path: impl AsRef<Path>) -> Result<ImplicitEWMAModel, failure::Error> {
    let buffered = File::open(model_path.as_ref()).map(BufReader::new)?;
    Ok(serde_json::from_reader(buffered)?)
}
use std::iter::Iterator;
/// Recommend books for a reader who liked `input_titles`.
///
/// Titles missing from the metadata at `books_path` are reported on stdout
/// and ignored. The result holds up to ten titles, best first. It is empty
/// when none of the input titles is known, because then there is no history
/// to base a recommendation on.
///
/// # Errors
///
/// Returns an error in three cases:
/// - the metadata cannot be read;
/// - the model fails to build a user representation;
/// - the model fails to score the items.
fn predict(
    books_path: impl AsRef<Path>,
    input_titles: &[String],
    model: &ImplicitEWMAModel,
) -> Result<Vec<String>, failure::Error> {
    const NUM_RECOMMENDATIONS: usize = 10;

    let (id_to_title, title_to_id) = deserialize_books(books_path.as_ref())?;

    let mut input_indices = Vec::with_capacity(input_titles.len());
    for title in input_titles {
        match title_to_id.get(title) {
            Some(&id) => input_indices.push(id),
            None => println!("No such title, ignoring: {}", title),
        }
    }
    if input_indices.is_empty() {
        return Ok(Vec::new());
    }

    // Score exactly the ids that have a title. Scoring `0..len` assumes the
    // ids are dense and zero-based. If they are not (e.g. they start at 1),
    // an id without a title can reach the top results and make the final
    // lookup panic, and the highest id is never scored at all.
    // NOTE(review): assumes every id in the metadata is below the model's
    // item count, i.e. the model was fitted on ratings covering all books.
    let mut indices_to_score: Vec<usize> = id_to_title.keys().cloned().collect();
    indices_to_score.sort_unstable();

    let user = model.user_representation(&input_indices)?;
    let scores = model.predict(&user, &indices_to_score)?;

    let mut scored: Vec<_> = indices_to_score.into_iter().zip(scores).collect();
    // Highest score first; NaN scores compare as equal instead of panicking.
    scored.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

    // `take` (not slicing) so fewer than ten scored items cannot panic.
    Ok(scored
        .into_iter()
        .take(NUM_RECOMMENDATIONS)
        .filter_map(|(idx, _)| id_to_title.get(&idx).cloned())
        .collect())
}
use std::ffi::{OsStr, OsString};
/// Clap validator: accept `val` only if it names an existing regular file.
///
/// Directories are rejected as well. They pass `Path::exists`, but would
/// fail later when read as the model or books file. The same error text is
/// returned for every rejected value.
fn is_existing_file(val: &OsStr) -> Result<(), OsString> {
    if Path::new(val).is_file() {
        Ok(())
    } else {
        Err(OsString::from("Not an existing file"))
    }
}
/// Command-line entry point with two subcommands:
/// - `fit` downloads the data and trains a model;
/// - `predict` recommends books based on the titles given.
fn main() {
    use clap::{App, AppSettings, Arg, SubCommand};

    // NOTE(review): the ratings and books files are CSV (they are read with
    // `csv::Reader`), so their `.json` default names are misleading. The
    // defaults are kept so files downloaded by earlier runs are still found.
    let matches = App::new("Goodbooks Recommender")
        .version("0.1.0")
        .about("Recommends books using the goodbooks-10k dataset")
        .setting(AppSettings::SubcommandRequired)
        .subcommand(
            SubCommand::with_name("fit")
                .about("Downloads the goodbooks-10k data and fits a model")
                .arg(
                    Arg::with_name("ratings_url")
                        .help("URL of ratings data")
                        .long("ratings-url")
                        .default_value(
                            "https://github.com/zygmuntz/goodbooks-10k/raw/master/ratings.csv",
                        ),
                ).arg(
                    Arg::with_name("books_url")
                        .help("URL of books data")
                        .long("books-url")
                        .default_value(
                            "https://github.com/zygmuntz/goodbooks-10k/raw/master/books.csv",
                        ),
                ).arg(
                    Arg::with_name("ratings_filename")
                        .help("Specifies ratings filename")
                        .long("ratings-filename")
                        .default_value_os(OsStr::new("ratings.json")),
                ).arg(
                    Arg::with_name("model_filename")
                        .help("Specifies model file path")
                        .long("model-filename")
                        .default_value_os(OsStr::new("model.json")),
                ).arg(
                    Arg::with_name("books_filename")
                        .long("books-filename")
                        .help("Specifies books file path")
                        .default_value_os(OsStr::new("books.json")),
                ),
        ).subcommand(
            SubCommand::with_name("predict")
                .about("Makes predictions")
                .arg(
                    Arg::with_name("titles")
                        .help("Titles to base predictions on")
                        .index(1)
                        .multiple(true)
                        .required(true),
                ).arg(
                    Arg::with_name("model_filename")
                        .help("Specifies model file path")
                        .long("model-filename")
                        .default_value_os(OsStr::new("model.json"))
                        .validator_os(is_existing_file),
                ).arg(
                    Arg::with_name("books_filename")
                        .help("Specifies books file path")
                        .long("books-filename")
                        .default_value_os(OsStr::new("books.json"))
                        .validator_os(is_existing_file),
                ),
        ).get_matches();

    // Every path argument has a default value, so the `unwrap`s on them
    // cannot fail. `value_of_os` is used rather than `value_of`: the latter
    // panics on paths that are not valid UTF-8.
    match matches.subcommand() {
        ("fit", Some(matches)) => {
            let ratings_path = Path::new(matches.value_of_os("ratings_filename").unwrap());
            let model_path = Path::new(matches.value_of_os("model_filename").unwrap());
            let books_path = Path::new(matches.value_of_os("books_filename").unwrap());
            let ratings_url = matches.value_of("ratings_url").unwrap();
            let books_url = matches.value_of("books_url").unwrap();
            main_build(model_path, ratings_path, books_path, ratings_url, books_url)
        }
        ("predict", Some(matches)) => {
            let model_path = Path::new(matches.value_of_os("model_filename").unwrap());
            let books_path = Path::new(matches.value_of_os("books_filename").unwrap());
            // `unwrap_or_else` builds the message only on failure, and it
            // includes the underlying error.
            let model = deserialize_model(model_path).unwrap_or_else(|err| {
                panic!("Unable to deserialize {}: {}", model_path.display(), err)
            });
            // `titles` is marked `required(true)`, so it is always present.
            let titles: Vec<String> = matches
                .values_of("titles")
                .unwrap()
                .map(String::from)
                .collect();
            let predictions =
                predict(books_path, &titles, &model).expect("Unable to get predictions");
            if predictions.is_empty() {
                println!("No predictions found.")
            } else {
                println!("Predictions:");
                for prediction in &predictions {
                    println!(" {}", prediction);
                }
            }
        }
        // `SubcommandRequired` guarantees one of the subcommands above.
        _ => unreachable!(),
    }
}
@Aankhen
Copy link
Author

Aankhen commented Sep 1, 2018

This is based on maciejkula’s blog post, with a few tweaks to try out Rust 2018 features and a few useful crates.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment