Skip to content

Instantly share code, notes, and snippets.

@ramv
Created April 28, 2017 21:27
Show Gist options
  • Save ramv/0093095fa87bef02483488cde44146ce to your computer and use it in GitHub Desktop.
Save ramv/0093095fa87bef02483488cde44146ce to your computer and use it in GitHub Desktop.
Similarity Analysis Example
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import java.util.Iterator;
/**
 * Item-to-item similarity analysis over explicit ratings.
 *
 * <p>For every pair of videos rated by at least one common user, computes the Pearson
 * correlation, a regularized correlation, the cosine similarity and the Jaccard similarity of
 * their rating vectors, and writes one text line per pair under {@code <outDir>/similarities}.
 *
 * <p>TODO: validate against the MovieLens data set.
 */
public class SimilarityAnalysis {
    static final Logger LOGGER = LoggerFactory.getLogger(SimilarityAnalysis.class);

    /** Number of virtual rating pairs used to shrink the correlation toward the prior. */
    private static final double PRIOR_COUNT = 10;

    /** Correlation value that the regularized correlation is shrunk toward. */
    private static final double PRIOR_CORRELATION = 0;

    /**
     * Computes pairwise video similarities from {@code ratings} and saves them as text under
     * {@code outDir + "/similarities"}.
     *
     * <p>Video ids come from {@link Rating#product()} and users from {@link Rating#user()}. Each
     * output record is keyed by the first video id and carries the row
     * {@code (videoId1, videoId2, correlation, regularizedCorrelation, cosine, jaccard)} with
     * {@code videoId1 < videoId2}.
     *
     * @param ratings explicit ratings; NOTE(review): presumably at most one rating per
     *     (user, video) — duplicates would inflate the pair statistics, confirm with the source
     * @param jsc Spark context; not used by this implementation, kept for signature compatibility
     * @param outDir directory under which the {@code similarities} output is written
     */
    private static void computeSimilarVideos(JavaRDD<Rating> ratings, JavaSparkContext jsc, String outDir) {
        // Number of ratings per video, keyed on video id.
        JavaPairRDD<Integer, Long> numRatersPerVideo = ratings
                .mapToPair(r -> new Tuple2<>(r.product(), 1L))
                .reduceByKey((a, b) -> a + b);

        // Attach each video's rating count to its ratings and re-key on user:
        // (user, (videoId, rating, numRaters)). Cached because the self-join reads it twice.
        JavaPairRDD<Integer, Rating> ratingsByVideo = ratings.mapToPair(r -> new Tuple2<>(r.product(), r));
        JavaPairRDD<Integer, VideoRating> ratingsByUser = ratingsByVideo
                .join(numRatersPerVideo)
                .mapToPair(entry -> {
                    Rating r = entry._2()._1();
                    VideoRating videoRating = new VideoRating(r.product(), r.rating(), entry._2()._2());
                    return new Tuple2<>(r.user(), videoRating);
                })
                .cache();

        // Self-join on user to enumerate pairs of videos rated by the same user. Requiring
        // videoId1 < videoId2 keeps each unordered pair exactly once and drops (v, v) self-pairs.
        JavaPairRDD<Integer, Tuple2<VideoRating, VideoRating>> ratingPairs = ratingsByUser
                .join(ratingsByUser)
                .filter(pair -> pair._2()._1().videoId < pair._2()._2().videoId);

        // Sum the per-user contributions into sufficient statistics for each video pair.
        JavaPairRDD<Tuple2<Integer, Integer>, PairStats> pairStats = ratingPairs
                .mapToPair(pair -> {
                    VideoRating first = pair._2()._1();
                    VideoRating second = pair._2()._2();
                    Tuple2<Integer, Integer> videoIds = new Tuple2<>(first.videoId, second.videoId);
                    return new Tuple2<>(videoIds, PairStats.of(first, second));
                })
                .reduceByKey(PairStats::merge);

        // Each count() launches a full Spark job, so only pay for it when debugging.
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("number of ratings with size {}", ratingsByUser.count());
            LOGGER.debug("number of rating pairs {}", ratingPairs.count());
            LOGGER.debug("number of pair stats {}", pairStats.count());
        }

        // Compute the similarity measures for each video pair.
        JavaPairRDD<Integer, Row> similarities = pairStats.mapToPair(entry -> {
            Tuple2<Integer, Integer> videos = entry._1();
            PairStats s = entry._2();
            double size = s.size;
            double corr = correlation(size, s.dotProduct, s.ratingSum, s.rating2Sum, s.ratingSq, s.rating2Sq);
            double regCorr = regularizedCorrelation(size, s.dotProduct, s.ratingSum, s.rating2Sum,
                    s.ratingSq, s.rating2Sq, PRIOR_COUNT, PRIOR_CORRELATION);
            double cosSim = cosineSimilarity(s.dotProduct, Math.sqrt(s.ratingSq), Math.sqrt(s.rating2Sq));
            double jaccard = jaccardSimilarity(size, s.numRaters, s.numRaters2);
            return new Tuple2<>(videos._1(),
                    RowFactory.create(videos._1(), videos._2(), corr, regCorr, cosSim, jaccard));
        });
        similarities.saveAsTextFile(outDir + "/similarities");
        ratingsByUser.unpersist();
    }

    /** One rating of a video, together with the total number of ratings that video received. */
    private static final class VideoRating implements java.io.Serializable {
        private static final long serialVersionUID = 1L;

        final int videoId;
        final double rating;
        final long numRaters;

        VideoRating(int videoId, double rating, long numRaters) {
            this.videoId = videoId;
            this.rating = rating;
            this.numRaters = numRaters;
        }
    }

    /** Sufficient statistics for the similarity measures of one (videoId1, videoId2) pair. */
    private static final class PairStats implements java.io.Serializable {
        private static final long serialVersionUID = 1L;

        final long size;          // number of co-rating pairs (users who rated both videos)
        final double dotProduct;  // sum of rating1 * rating2
        final double ratingSum;   // sum of rating1
        final double rating2Sum;  // sum of rating2
        final double ratingSq;    // sum of rating1^2
        final double rating2Sq;   // sum of rating2^2
        final long numRaters;     // total ratings of video 1
        final long numRaters2;    // total ratings of video 2

        private PairStats(long size, double dotProduct, double ratingSum, double rating2Sum,
                double ratingSq, double rating2Sq, long numRaters, long numRaters2) {
            this.size = size;
            this.dotProduct = dotProduct;
            this.ratingSum = ratingSum;
            this.rating2Sum = rating2Sum;
            this.ratingSq = ratingSq;
            this.rating2Sq = rating2Sq;
            this.numRaters = numRaters;
            this.numRaters2 = numRaters2;
        }

        /** Statistics contributed by a single user who rated both videos. */
        static PairStats of(VideoRating first, VideoRating second) {
            return new PairStats(1, first.rating * second.rating, first.rating, second.rating,
                    first.rating * first.rating, second.rating * second.rating,
                    first.numRaters, second.numRaters);
        }

        /** Combines two partial statistics for the same pair; rater totals are per-video constants. */
        PairStats merge(PairStats other) {
            return new PairStats(size + other.size, dotProduct + other.dotProduct,
                    ratingSum + other.ratingSum, rating2Sum + other.rating2Sum,
                    ratingSq + other.ratingSq, rating2Sq + other.rating2Sq,
                    Math.max(numRaters, other.numRaters), Math.max(numRaters2, other.numRaters2));
        }
    }

    // *************************
    // * SIMILARITY MEASURES
    // *************************

    /**
     * The correlation between two vectors A, B is
     * cov(A, B) / (stdDev(A) * stdDev(B))
     *
     * <p>This is equivalent to
     * [n * dotProduct(A, B) - sum(A) * sum(B)] /
     * sqrt{ [n * norm(A)^2 - sum(A)^2] [n * norm(B)^2 - sum(B)^2] }
     *
     * <p>The result is not finite (typically NaN) when either vector has zero variance, e.g. when
     * there is only one common rater.
     */
    private static double correlation(double size, double dotProduct, double ratingSum,
            double rating2Sum, double ratingNormSq, double rating2NormSq) {
        double numerator = size * dotProduct - ratingSum * rating2Sum;
        double denominator = Math.sqrt(size * ratingNormSq - ratingSum * ratingSum)
                * Math.sqrt(size * rating2NormSq - rating2Sum * rating2Sum);
        return numerator / denominator;
    }

    /**
     * Regularizes the correlation by adding virtual pseudocounts over a prior:
     * RegularizedCorrelation = w * ActualCorrelation + (1 - w) * PriorCorrelation
     * where w = # actualPairs / (# actualPairs + # virtualPairs).
     *
     * <p>Propagates NaN from {@link #correlation} unchanged.
     */
    private static double regularizedCorrelation(double size, double dotProduct, double ratingSum,
            double rating2Sum, double ratingNormSq, double rating2NormSq,
            double virtualCount, double priorCorrelation) {
        double unregularizedCorrelation =
                correlation(size, dotProduct, ratingSum, rating2Sum, ratingNormSq, rating2NormSq);
        double w = size / (size + virtualCount);
        return w * unregularizedCorrelation + (1 - w) * priorCorrelation;
    }

    /**
     * The cosine similarity between two vectors A, B is
     * dotProduct(A, B) / (norm(A) * norm(B)).
     */
    private static double cosineSimilarity(double dotProduct, double ratingNorm, double rating2Norm) {
        return dotProduct / (ratingNorm * rating2Norm);
    }

    /**
     * The Jaccard similarity between two sets A, B is
     * |Intersection(A, B)| / |Union(A, B)|.
     *
     * @param usersInCommon size of the intersection (users who rated both videos)
     * @param totalUsers1 number of users who rated the first video
     * @param totalUsers2 number of users who rated the second video
     */
    private static double jaccardSimilarity(double usersInCommon, double totalUsers1, double totalUsers2) {
        double union = totalUsers1 + totalUsers2 - usersInCommon;
        return usersInCommon / union;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment