Spark LDA prediction example
// imports assumed for this snippet (Spark 1.x APIs)
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.mllib.clustering.LocalLDAModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import scala.Tuple2;

// NOTE: these methods are assumed to live in an enclosing class that defines
// an slf4j Logger named LOG and the cleanContent(...) helper referenced below.

// example usage
//
// String test_doc = "Antwaan Randle El Comments on NFL's Impact on His Life, Choice to Play Football \n" + "Former NFL wide receiver Antwaan Randle El, who last played with the Pittsburgh Steelers in 2010, does not have fond memories of his playing days with the benefit of hindsight. \n" +
// "Speaking to the Pittsburgh Post-Gazette (via Jacob Emert of the Washington Post), Randle El said he wished he had played baseball coming out of college: \"If I could go back, I wouldn’t [play football]. I would play baseball. I got drafted by the Cubs in the 14th round, but I didn’t play baseball because of my parents. They made me go to school. Don’t get me wrong, I love the game of football. But right now, I could still be playing baseball.\"\n" +
// "Randle El said there are certain days he walks down stairs \"sideways\" because of the toll football took on him. \n" +
// "Randle El, who is 36 years old and spent nine seasons in the NFL with Pittsburgh and Washington—winning a Super Bowl with the Steelers after the 2005 season—told the Post-Gazette he has difficulty remembering things: ");
//
// Vector prediction = predictTopicsForDocument(sparkContext, ldaModel, test_doc);
// showTopTopicsForPrediction(prediction);
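//
// A minimal setup sketch (hypothetical, not part of the original gist: the app name
// and "path/to/lda-model" are placeholders, and org.apache.spark.SparkConf would also
// need to be imported) showing how the JavaSparkContext and LocalLDAModel arguments
// might be obtained:
//
// JavaSparkContext sparkContext = new JavaSparkContext(
//         new SparkConf().setAppName("lda-prediction-example"));
// LocalLDAModel ldaModel = LocalLDAModel.load(sparkContext.sc(), "path/to/lda-model");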
public Vector predictTopicsForDocument(JavaSparkContext sparkContext,
                                       LocalLDAModel ldaModel,
                                       String document) {
    // normalize/stopword-filter/etc removed for brevity
    //document = cleanContent(document.replace(" ", "+"));
    // create an in-memory list/RDD of docs, size 1, for our test doc
    List<String> docList = new ArrayList<>();
    docList.add("test_url " + document);
    JavaRDD<String> docsRDD = sparkContext.parallelize(docList, 1);
    DataFrame dataFrame = getDataFrameForDocsRDD(sparkContext, docsRDD);
    CountVectorizerModel cvModel = getCVModelForDataFrame(dataFrame);
    JavaPairRDD<Long, Vector> indexedDocs = id2doc(cvModel, dataFrame);
    Vector test_vector = indexedDocs.collect().get(0)._2();
    return ldaModel.topicDistribution(test_vector);
}
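
// Sketch (an assumption, not in the original gist): if the whole indexed corpus should
// be scored rather than a single collected vector, the mllib LocalLDAModel also exposes
// a Java-friendly batch topicDistributions(...) over a JavaPairRDD<Long, Vector>, which
// avoids the collect() above. Hypothetical usage:
//
// JavaPairRDD<Long, Vector> indexedDocs = id2doc(cvModel, dataFrame);
// JavaPairRDD<Long, Vector> topicsPerDoc = ldaModel.topicDistributions(indexedDocs);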
public DataFrame getDataFrameForDocsRDD(JavaSparkContext sparkContext, JavaRDD<String> docs) {
    SQLContext sqlContext = new SQLContext(sparkContext);
    // parse the text input, which is in the format
    // url(id) word1 word2 ... wordn
    JavaRDD<Row> rowRDD = docs.map(
        new Function<String, Row>() {
            @Override
            public Row call(String record) throws Exception {
                int idxFirstSpace = record.indexOf(" ");
                String id = record.substring(0, idxFirstSpace);
                String[] words = record.substring(idxFirstSpace + 1).split(" ");
                return RowFactory.create(id, words);
            }
        }
    );
    return sqlContext.createDataFrame(rowRDD, getSchema());
}
public StructType getSchema() {
    List<StructField> fields = new ArrayList<>();
    fields.add(DataTypes.createStructField("id", DataTypes.StringType, false));
    fields.add(new StructField("words", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()));
    return DataTypes.createStructType(fields);
}
public CountVectorizerModel getCVModelForDataFrame(DataFrame dataFrame) {
    // fit a CountVectorizerModel from the training/test corpus
    return new CountVectorizer()
        .setInputCol("words")
        .setOutputCol("vector")
        .fit(dataFrame);
}
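
// Sketch (an assumption, not part of the original gist): for predictions against a
// previously trained LDA model, the term indexes must line up with the training
// vocabulary, so rather than fitting a fresh CountVectorizer on the test document it
// is usually safer to rebuild the model from the saved training vocabulary, e.g.:
//
// String[] trainingVocabulary = ...;  // persisted when the LDA model was trained
// CountVectorizerModel cvModel = new CountVectorizerModel(trainingVocabulary)
//         .setInputCol("words")
//         .setOutputCol("vector");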
public JavaPairRDD<Long, Vector> id2doc(CountVectorizerModel cvModel, DataFrame docs) {
    // this seems inefficient; create a JavaRDD of Vectors from the transformed dataset
    // as an intermediate step towards getting a JavaPairRDD below.
    // I'm sure there's a way to do this all in one step
    // (a one-step variant is sketched after this method).
    JavaRDD<Vector> vecs = cvModel.transform(docs).toJavaRDD().map(
        new Function<Row, Vector>() {
            @Override
            public Vector call(Row row) throws Exception {
                // column 2 is the "vector" output column produced by the CountVectorizer
                return (Vector) row.getAs(2);
            }
        }
    );
    // Index documents with unique IDs
    return JavaPairRDD.fromJavaRDD(vecs.zipWithIndex().map(
        new Function<Tuple2<Vector, Long>, Tuple2<Long, Vector>>() {
            public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
                return doc_id.swap();
            }
        }
    ));
}
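
// Sketch (an assumption, not in the original gist): with Java 8 lambdas the two-step
// map + swap above can be collapsed into a single chain, e.g.:
//
// JavaPairRDD<Long, Vector> indexed = cvModel.transform(docs)
//         .toJavaRDD()
//         .map(row -> (Vector) row.getAs("vector"))
//         .zipWithIndex()
//         .mapToPair(Tuple2::swap);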
public void showTopTopicsForPrediction(Vector prediction) {
    double[] probabilities = prediction.toArray();
    ArrayIndexComparator comparator = new ArrayIndexComparator(probabilities);
    Integer[] indexes = comparator.createIndexArray();
    Arrays.sort(indexes, comparator);
    int top = Math.min(10, indexes.length);
    LOG.info("Top " + top + " topics:");
    for (int i = 0; i < top; i++) {
        LOG.info("Topic " + indexes[i] + ": " + probabilities[indexes[i]]);
    }
}
static class ArrayIndexComparator implements Comparator<Integer> {
    private final double[] arr;

    public ArrayIndexComparator(double[] arr) {
        this.arr = arr;
    }

    public Integer[] createIndexArray() {
        Integer[] indexes = new Integer[arr.length];
        for (int i = 0; i < arr.length; i++) {
            indexes[i] = i; // Autoboxing
        }
        return indexes;
    }

    @Override
    public int compare(Integer index1, Integer index2) {
        // Auto-unbox from Integer to int to use as array indexes;
        // reversed arguments give a descending sort by probability.
        return Double.compare(arr[index2], arr[index1]);
    }
}