Skip to content

Instantly share code, notes, and snippets.

Created January 5, 2017 18:06
Show Gist options
  • Save Jeffwan/15465a4084081bd758a48febb0042daf to your computer and use it in GitHub Desktop.
Save Jeffwan/15465a4084081bd758a48febb0042daf to your computer and use it in GitHub Desktop.
Machine Learning Pipeline
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.functions;
/**
 * Step-by-step generated DataFrames are not needed on their own; they are only
 * helpful for observing the data. The Pipeline chains all the stages together.
 *
 * @author jiashan
 */
public class WikiPageClustering {

  /**
   * Clusters wiki pages with a Spark ML pipeline:
   * tokenize -> remove stop words -> hashing TF -> IDF -> normalize -> k-means.
   */
  public static void main(String[] args) {
    SparkConf sparkConf =
        new SparkConf().setMaster("local[2]").setAppName(WikiPageClustering.class.getSimpleName());
    SparkContext sc = new SparkContext(sparkConf);
    SQLContext sqlContext = new SQLContext(sc);

    // NOTE(review): this line was truncated in the pasted source; presumably the
    // original loaded the wiki dump via the DataFrameReader, e.g.
    //"location") — confirm the format and path against the data set.
    DataFrame wikiDF ="location").cache();

    // Select all original columns plus a lower-cased copy of "text".
    // (Demonstrates mixing a String "*" selector with a Column argument.)
    DataFrame wikiLoweredDF =
"*"), functions.lower(wikiDF.col("text")).alias("lowerText"));

    // Step 1: split the lowered text into words on non-word characters.
    RegexTokenizer tokenizer =
        new RegexTokenizer().setInputCol("lowerText").setOutputCol("words").setPattern("\\W+");
    // DataFrame wikiWordsDF = tokenizer.transform(wikiLoweredDF);

    // Step 2: drop common stop words.
    StopWordsRemover remover =
        new StopWordsRemover().setInputCol("words").setOutputCol("noStopWords");
    // DataFrame noStopWordsListDf = remover.transform(wikiWordsDF);

    // Step 3: hash the remaining words into a fixed-size term-frequency vector.
    int numFeatures = 20000;
    HashingTF hashingTF =
        new HashingTF().setInputCol("noStopWords").setOutputCol("hashingTF").setNumFeatures(numFeatures);
    // DataFrame featurizedDF = hashingTF.transform(noStopWordsListDf);

    // Step 4: reweight term frequencies by inverse document frequency.
    IDF idf = new IDF().setInputCol("hashingTF").setOutputCol("idf");
    // IDFModel idfModel =;

    // Step 5: normalize the TF-IDF vectors so k-means distances are scale-free.
    Normalizer normalizer = new Normalizer().setInputCol("idf").setOutputCol("features");

    // Step 6: cluster into numCluster groups; fixed seed keeps runs reproducible.
    int numCluster = 100;
    KMeans kmeans =
        new KMeans().setFeaturesCol("features").setPredictionCol("prediction").setK(numCluster).setSeed(0);

    // Step 7: chain all stages together and train the pipeline.
    // BUG FIX: the original called toArray(new Pipeline[] {}), which allocates a
    // Pipeline[] to hold PipelineStage elements and throws ArrayStoreException at
    // runtime. The array component type must be PipelineStage.
    List<PipelineStage> pipelineStages =
        Arrays.asList(tokenizer, remover, hashingTF, idf, normalizer, kmeans);
    Pipeline pipeline = new Pipeline().setStages(pipelineStages.toArray(new PipelineStage[0]));
    PipelineModel model =;
    // TODO: store trained model and then we can reuse next time.

    // Step 8: use the trained model to predict cluster assignments for new data.
    DataFrame predictionDF = model.transform(wikiLoweredDF);

    sc.stop();
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment