Skip to content

Instantly share code, notes, and snippets.

@yaravind
Forked from albrzykowski/TextClassification.java
Created April 28, 2020 18:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yaravind/4c2689164ee43d0d34bdaa568f5c6229 to your computer and use it in GitHub Desktop.
import java.util.Arrays;
import java.util.List;
import org.apache.hadoop.yarn.webapp.hamlet.HamletSpec.P;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.RelationalGroupedDataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.catalyst.expressions.Randn;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.netlib.util.doubleW;
import breeze.linalg.randn;
import scala.Tuple2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.ml.feature.StringIndexer;
import static org.apache.spark.sql.functions.*;
/**
 * Text-classification example on the SEL emotion lexicon using Spark ML.
 *
 * <p>Reads a tab-separated file with columns (word, polarity, category), builds a
 * StringIndexer -> Tokenizer -> StopWordsRemover -> HashingTF -> IDF ->
 * LogisticRegression pipeline, tunes it with 2-fold cross-validation over a small
 * hyper-parameter grid, and reports weighted precision on a held-out 30% split.
 */
public class App {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
            .builder()
            .appName("Java Spark SQL Example")
            .getOrCreate();

        // Explicit schema; DROPMALFORMED silently discards rows that fail to parse.
        StructType schema = new StructType()
            .add("word", "string")
            .add("polarity", "double")
            .add("category", "string");

        Dataset<Row> df = spark.read()
            .option("mode", "DROPMALFORMED")
            .option("delimiter", "\t")
            .option("header", "true")
            .schema(schema)
            .csv("src/main/resources/SEL-utf-8.txt");
        df.show(20);

        // Shuffle, then split 70/30 into training and test sets.
        Dataset<Row>[] split = df.orderBy(rand()).randomSplit(new double[] {0.7, 0.3});
        Dataset<Row> training = split[0];
        Dataset<Row> test = split[1];

        // BUG FIX: the original pipeline read columns "label" and "text", which do
        // not exist in the schema above (word/polarity/category) and would make
        // fit() fail at runtime. "category" is the class label; "word" is the text.
        StringIndexer indexer = new StringIndexer()
            .setInputCol("category")
            .setOutputCol("labelIndexed");
        Tokenizer tokenizer = new Tokenizer()
            .setInputCol("word")
            .setOutputCol("tokens");
        // NOTE(review): SEL is a Spanish lexicon but English stop words are removed
        // here — confirm whether loadDefaultStopWords("spanish") was intended.
        StopWordsRemover stopWordsRemover = new StopWordsRemover()
            .setInputCol("tokens")
            .setOutputCol("clearedFromStopwords")
            .setStopWords(StopWordsRemover.loadDefaultStopWords("english"));
        HashingTF hashingTF = new HashingTF()
            .setInputCol("clearedFromStopwords")
            .setOutputCol("rawFeatures")
            .setNumFeatures(50000);
        IDF idf = new IDF()
            .setInputCol("rawFeatures")
            .setOutputCol("features");
        LogisticRegression lr = new LogisticRegression()
            .setMaxIter(10)
            .setRegParam(0.3)
            .setFamily("multinomial")
            .setLabelCol("labelIndexed");

        Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[] {indexer, tokenizer, stopWordsRemover, hashingTF, idf, lr});

        // Hyper-parameter grid explored by the cross-validator; setRegParam/setMaxIter
        // values above are overridden by these grid values during tuning.
        ParamMap[] paramGrid = new ParamGridBuilder()
            .addGrid(lr.maxIter(), new int[] { 10, 20 })
            .addGrid(lr.regParam(), new double[] { 0.1, 1.0 })
            .addGrid(lr.elasticNetParam(), new double[] { 0.7 })
            .addGrid(hashingTF.numFeatures(), new int[] {50000})
            .build();

        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
            .setLabelCol("labelIndexed")
            .setPredictionCol("prediction")
            .setMetricName("weightedPrecision");

        CrossValidator validator = new CrossValidator()
            .setNumFolds(2)
            .setEstimator(pipeline)
            .setEvaluator(evaluator)
            .setEstimatorParamMaps(paramGrid);

        PipelineModel model = (PipelineModel) validator.fit(training).bestModel();
        try {
            model.save("src/main/resources/model");
        } catch (Exception e) {
            // BUG FIX: the original swallowed this exception silently; a failed save
            // is non-fatal for the evaluation below, but must at least be reported.
            System.err.println("Failed to save model to src/main/resources/model: " + e);
        }

        Dataset<Row> predictions = model.transform(test);
        // Reuse the tuning evaluator instead of constructing an identical second one;
        // the reported metric is weighted precision, not accuracy.
        double weightedPrecision = evaluator.evaluate(predictions);

        // The original withColumn(name, new Column(name)) calls were no-ops (each
        // replaced a column with itself) and were dropped; select() is sufficient.
        predictions
            .select("category", "prediction", "labelIndexed", "word")
            .show(500);
        System.out.println("Weighted precision: " + weightedPrecision);

        spark.stop();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment