import java.io.Serializable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
/**
 * A single user-to-product rating record.
 *
 * <p>Shaped as a JavaBean (public no-arg constructor plus getters) and marked
 * {@link Serializable}. Instances are normally produced from one line of
 * {@code "::"}-delimited text via {@link #parseRating(String)}.
 */
public static class Rating implements Serializable {
    private int userId;
    private int productId;
    private float rating;
    private long timestamp;

    /** Creates a rating whose fields all hold their zero defaults. */
    public Rating() {}

    /**
     * Creates a fully populated rating.
     *
     * @param userId    identifier of the user who gave the rating
     * @param productId identifier of the product that was rated
     * @param rating    the rating value
     * @param timestamp timestamp attached to the rating (unit is not fixed by this class)
     */
    public Rating(
        int userId,
        int productId,
        float rating,
        long timestamp
    ) {
        this.userId = userId;
        this.productId = productId;
        this.rating = rating;
        this.timestamp = timestamp;
    }

    /** @return the user identifier */
    public int getUserId() {
        return userId;
    }

    /** @return the product identifier */
    public int getProductId() {
        return productId;
    }

    /** @return the rating value */
    public float getRating() {
        return rating;
    }

    /** @return the timestamp attached to the rating */
    public long getTimestamp() {
        return timestamp;
    }

    /**
     * Parses one line of the form {@code userId::productId::rating::timestamp}.
     *
     * <p>Whitespace around each field is ignored.
     *
     * @param str the raw input line
     * @return the parsed rating
     * @throws IllegalArgumentException if {@code str} is {@code null} or does not
     *         split into exactly 4 fields
     * @throws NumberFormatException if a field is not a valid number of the
     *         expected type
     */
    public static Rating parseRating(String str) {
        if (str == null) {
            throw new IllegalArgumentException("Rating line must not be null");
        }
        String[] fields = str.split("::");
        if (fields.length != 4) {
            // Report the actual count and the offending line so that bad records
            // can be located in the input.
            throw new IllegalArgumentException(
                "Each line must contain 4 fields, but found " + fields.length
                    + " in: \"" + str + "\"");
        }
        // Float.parseFloat ignores surrounding whitespace while Integer.parseInt
        // and Long.parseLong do not; trim every field so all four are treated alike.
        int userId = Integer.parseInt(fields[0].trim());
        int productId = Integer.parseInt(fields[1].trim());
        float rating = Float.parseFloat(fields[2].trim());
        long timestamp = Long.parseLong(fields[3].trim());
        return new Rating(userId, productId, rating, timestamp);
    }
}
// Read the raw ratings file as lines of text and parse each line into a Rating.
JavaRDD<Rating> ratingsRDD = spark
    .read()
    .textFile("s3a://ourco-product-recs/latest.txt") // <1>
    .javaRDD()
    .map(Rating::parseRating);

// Turn the RDD of Rating beans into a DataFrame using Rating as the bean class.
Dataset<Row> ratings =
    spark.createDataFrame(ratingsRDD, Rating.class);

// Randomly split with weights 0.8 / 0.2; only the first split is used below.
// No seed is passed, so the split differs from run to run.
Dataset<Row>[] splits =
    ratings.randomSplit(new double[]{0.8, 0.2}); // <2>
Dataset<Row> training = splits[0];

// Configure ALS; the column names match the Rating bean's properties.
ALS als = new ALS()
    .setMaxIter(5)
    .setRegParam(0.01)
    .setUserCol("userId")
    .setItemCol("productId")
    .setRatingCol("rating");
ALSModel model = als.fit(training);

// Per the Spark ALS documentation, "drop" omits rows whose prediction would be NaN.
model.setColdStartStrategy("drop"); // <3>

final int NUM_RECOMMENDATIONS = 5;

// activelyShoppingUsers is defined outside this listing; presumably a Dataset
// holding the user ids to generate recommendations for.
Dataset<Row> userRecs = model.recommendForUserSubset(
    activelyShoppingUsers,
    NUM_RECOMMENDATIONS
); // <4>
Created
January 21, 2020 17:17
-
-
Save nathanleclaire/f849ca9f32a78f439f9ee3f961bcfb25 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.Serializable; | |
import org.apache.spark.api.java.JavaRDD; | |
import org.apache.spark.ml.evaluation.RegressionEvaluator; | |
import org.apache.spark.ml.recommendation.ALS; | |
import org.apache.spark.ml.recommendation.ALSModel; | |
public static class Rating implements Serializable { | |
private int userId; | |
private int productId; | |
private float rating; | |
private long timestamp; | |
public Rating() {} | |
public Rating( | |
int userId, | |
int productId, | |
float rating, | |
long timestamp | |
) { | |
this.userId = userId; | |
this.productId = productId; | |
this.rating = rating; | |
this.timestamp = timestamp; | |
} | |
public int getUserId() { | |
return userId; | |
} | |
public int getProductId() { | |
return productId; | |
} | |
public float getRating() { | |
return rating; | |
} | |
public long getTimestamp() { | |
return timestamp; | |
} | |
public static Rating parseRating(String str) { | |
String[] fields = str.split("::"); | |
if (fields.length != 4) { | |
throw new | |
IllegalArgumentException("Each line must contain 4 fields"); | |
} | |
int userId = Integer.parseInt(fields[0]); | |
int productId = Integer.parseInt(fields[1]); | |
float rating = Float.parseFloat(fields[2]); | |
long timestamp = Long.parseLong(fields[3]); | |
return new Rating(userId, productId, rating, timestamp); | |
} | |
} | |
// Read the raw ratings file as lines of text and parse each line into a Rating.
JavaRDD<Rating> ratingsRDD = spark
    .read()
    .textFile("s3a://ourco-product-recs/latest.txt")
    .javaRDD()
    .map(Rating::parseRating);

// Turn the RDD of Rating beans into a DataFrame using Rating as the bean class.
Dataset<Row> ratings =
    spark.createDataFrame(ratingsRDD, Rating.class);

// Randomly split with weights 0.8 / 0.2; only the first split is used below.
// No seed is passed, so the split differs from run to run.
Dataset<Row>[] splits =
    ratings.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> training = splits[0];

// Configure ALS; the column names match the Rating bean's properties.
ALS als = new ALS()
    .setMaxIter(5)
    .setRegParam(0.01)
    .setUserCol("userId")
    .setItemCol("productId")
    .setRatingCol("rating");
ALSModel model = als.fit(training);

// Per the Spark ALS documentation, "drop" omits rows whose prediction would be NaN.
model.setColdStartStrategy("drop");

final int NUM_RECOMMENDATIONS = 5;

// activelyShoppingUsers is defined outside this listing; presumably a Dataset
// holding the user ids to generate recommendations for.
Dataset<Row> userRecs = model.recommendForUserSubset(
    activelyShoppingUsers,
    NUM_RECOMMENDATIONS
);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment