Skip to content

Instantly share code, notes, and snippets.

@nathanleclaire
Created January 21, 2020 17:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nathanleclaire/f849ca9f32a78f439f9ee3f961bcfb25 to your computer and use it in GitHub Desktop.
Save nathanleclaire/f849ca9f32a78f439f9ee3f961bcfb25 to your computer and use it in GitHub Desktop.
import java.io.Serializable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
public static class Rating implements Serializable {
  private int userId;
  private int productId;
  private float rating;
  private long timestamp;

  /** No-arg constructor, presumably kept so the class stays a JavaBean for reflection-based tools. */
  public Rating() {}

  /**
   * Creates a rating record.
   *
   * @param userId id of the user who gave the rating
   * @param productId id of the rated product
   * @param rating the rating value
   * @param timestamp the timestamp field from the input line (units not established here)
   */
  public Rating(
    int userId,
    int productId,
    float rating,
    long timestamp
  ) {
    this.userId = userId;
    this.productId = productId;
    this.rating = rating;
    this.timestamp = timestamp;
  }

  /** @return the user id */
  public int getUserId() {
    return userId;
  }

  /** @return the product id */
  public int getProductId() {
    return productId;
  }

  /** @return the rating value */
  public float getRating() {
    return rating;
  }

  /** @return the timestamp field */
  public long getTimestamp() {
    return timestamp;
  }

  /**
   * Parses one {@code userId::productId::rating::timestamp} line into a {@link Rating}.
   *
   * <p>Whitespace around each field is ignored, so {@code " 1 :: 2 :: 3.5 :: 4 "} parses the
   * same as {@code "1::2::3.5::4"}.
   *
   * @param str a single {@code ::}-separated input line
   * @return the parsed rating
   * @throws IllegalArgumentException if the line does not split into exactly four fields
   * @throws NumberFormatException if a field is not a valid number of its type
   */
  public static Rating parseRating(String str) {
    // String.split drops trailing empty strings, so a trailing "::" is tolerated.
    String[] fields = str.split("::");
    if (fields.length != 4) {
      throw new IllegalArgumentException(
          "Each line must contain 4 fields, but got " + fields.length + ": \"" + str + "\"");
    }
    // Float.parseFloat already ignores surrounding whitespace while Integer.parseInt and
    // Long.parseLong reject it; trim every field so all four are handled the same way.
    int userId = Integer.parseInt(fields[0].trim());
    int productId = Integer.parseInt(fields[1].trim());
    float rating = Float.parseFloat(fields[2].trim());
    long timestamp = Long.parseLong(fields[3].trim());
    return new Rating(userId, productId, rating, timestamp);
  }
}
// Read the ratings text file from S3 and turn every line into a Rating bean.
JavaRDD<Rating> ratingsRDD =
    spark.read().textFile("s3a://ourco-product-recs/latest.txt").javaRDD().map(Rating::parseRating);

// Build a DataFrame from the Rating beans and split it randomly with 0.8/0.2 weights;
// only the first split is used for training in this snippet.
Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
Dataset<Row>[] splits = ratings.randomSplit(new double[] {0.8, 0.2});
Dataset<Row> training = splits[0];

// Point ALS at the Rating columns, set the iteration count and regularization, then fit.
ALS als = new ALS()
    .setUserCol("userId")
    .setItemCol("productId")
    .setRatingCol("rating")
    .setMaxIter(5)
    .setRegParam(0.01);
ALSModel model = als.fit(training);
model.setColdStartStrategy("drop");

// Ask the model for the top NUM_RECOMMENDATIONS products for each user in
// activelyShoppingUsers (defined outside this snippet).
final int NUM_RECOMMENDATIONS = 5;
Dataset<Row> userRecs =
    model.recommendForUserSubset(activelyShoppingUsers, NUM_RECOMMENDATIONS);
import java.io.Serializable;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;

public static class Rating implements Serializable {
  private int userId;
  private int productId;
  private float rating;
  private long timestamp;

  /** No-arg constructor, presumably kept so the class stays a JavaBean for reflection-based tools. */
  public Rating() {}

  /**
   * Creates a rating record.
   *
   * @param userId id of the user who gave the rating
   * @param productId id of the rated product
   * @param rating the rating value
   * @param timestamp the timestamp field from the input line (units not established here)
   */
  public Rating(
    int userId,
    int productId,
    float rating,
    long timestamp
  ) {
    this.userId = userId;
    this.productId = productId;
    this.rating = rating;
    this.timestamp = timestamp;
  }

  /** @return the user id */
  public int getUserId() {
    return userId;
  }

  /** @return the product id */
  public int getProductId() {
    return productId;
  }

  /** @return the rating value */
  public float getRating() {
    return rating;
  }

  /** @return the timestamp field */
  public long getTimestamp() {
    return timestamp;
  }

  /**
   * Parses one {@code userId::productId::rating::timestamp} line into a {@link Rating}.
   *
   * <p>Whitespace around each field is ignored, so {@code " 1 :: 2 :: 3.5 :: 4 "} parses the
   * same as {@code "1::2::3.5::4"}.
   *
   * @param str a single {@code ::}-separated input line
   * @return the parsed rating
   * @throws IllegalArgumentException if the line does not split into exactly four fields
   * @throws NumberFormatException if a field is not a valid number of its type
   */
  public static Rating parseRating(String str) {
    // String.split drops trailing empty strings, so a trailing "::" is tolerated.
    String[] fields = str.split("::");
    if (fields.length != 4) {
      throw new IllegalArgumentException(
          "Each line must contain 4 fields, but got " + fields.length + ": \"" + str + "\"");
    }
    // Float.parseFloat already ignores surrounding whitespace while Integer.parseInt and
    // Long.parseLong reject it; trim every field so all four are handled the same way.
    int userId = Integer.parseInt(fields[0].trim());
    int productId = Integer.parseInt(fields[1].trim());
    float rating = Float.parseFloat(fields[2].trim());
    long timestamp = Long.parseLong(fields[3].trim());
    return new Rating(userId, productId, rating, timestamp);
  }
}

// Load the raw ratings file and parse each "userId::productId::rating::timestamp" line.
JavaRDD<Rating> ratingsRDD = spark
  .read()
  .textFile("s3a://ourco-product-recs/latest.txt") // <1>
  .javaRDD()
  .map(Rating::parseRating);

// Build a DataFrame from the Rating beans, then split it randomly with 0.8/0.2 weights.
// Only splits[0] is used for training in this snippet.
Dataset<Row> ratings =
  spark.createDataFrame(ratingsRDD, Rating.class);
Dataset<Row>[] splits =
  ratings.randomSplit(new double[]{0.8, 0.2}); // <2>
Dataset<Row> training = splits[0];

// Configure ALS on the Rating columns and fit it to the training split.
ALS als = new ALS()
  .setMaxIter(5)
  .setRegParam(0.01)
  .setUserCol("userId")
  .setItemCol("productId")
  .setRatingCol("rating");
ALSModel model = als.fit(training);

// Per the Spark ALS docs, "drop" makes transform() discard rows it cannot score (NaN)
// instead of emitting NaN predictions.
// NOTE(review): confirm whether this setting matters for recommendForUserSubset below.
model.setColdStartStrategy("drop"); // <3>

// Top-N product recommendations for the users in activelyShoppingUsers, which is defined
// elsewhere — presumably a Dataset containing a "userId" column (TODO confirm).
final int NUM_RECOMMENDATIONS = 5;
Dataset<Row> userRecs = model.recommendForUserSubset(
  activelyShoppingUsers,
  NUM_RECOMMENDATIONS
); // <4>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment