Code:

package com.acme;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;

import java.io.File;

import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;

import scala.Tuple2;
import scala.Tuple4;

public class Test {

    @Parameter(names = "-sparkMaster", required = true, description = "The URL of the Spark master in the form spark://host:port", arity = 1)
    private String sparkMasterURL;

    @Parameter(names = "-sparkHome", required = true, description = "Location of the local Spark installation directory", arity = 1)
    private String sparkHome;

    @Parameter(names = "-dataFile", required = true, description = "The local file to pick up input data from", arity = 1)
    private String dataFile;

    @Parameter(names = "-numPartitions", required = false, description = "The number of Spark data partitions to split across", arity = 1)
    private int numPartitions = 80;

    @Parameter(names = "-jarFileLoc", required = true, description = "The location of the jar file containing the application code with its non-Spark dependencies", arity = 1)
    private String jarFileLoc;

    private static final Logger logger = Logger.getLogger(Test.class);

    public static void main(String[] args) throws Exception {
        Test regressionDemo = new Test();
        JCommander cmdParser = new JCommander(regressionDemo);
        try {
            cmdParser.parse(args);
            regressionDemo.train();
        } catch (ParameterException exception) {
            System.err.println("Exception parsing one or more command line arguments: " + exception.getMessage());
            cmdParser.usage();
        }
    }

    public void train() throws Exception {
        SparkConf conf = new SparkConf()
                .setMaster(sparkMasterURL)
                .setAppName("SparkRegressionExample")
                .setSparkHome(sparkHome)
                .setJars(new String[]{jarFileLoc});
        conf.set("spark.executor.memory", "1500M");
        conf.set("spark.default.parallelism", "1");
        conf.set("spark.cores.max", "20");
        conf.set("spark.storage.memoryFraction", "0.9");

        JavaSparkContext ctx = new JavaSparkContext(conf);

        // Read the raw text file and parse each tab-separated line into a Tuple4.
        JavaRDD<String> rawData = createRawRDD(ctx, new File(dataFile));
        LineSplitter splitter = new LineSplitter();
        JavaRDD<Tuple4<Long, Long, Integer, Integer>> parsedRecords = rawData.map(splitter);
        logger.info("Raw records size: " + parsedRecords.count());

        // Shrink the RDD down to the requested number of partitions.
        JavaRDD<Tuple4<Long, Long, Integer, Integer>> reducedTupleRecords = parsedRecords.coalesce(numPartitions);
        logger.info("Filtered records size: " + reducedTupleRecords.count());

        // Re-key each Tuple4 as a ((x, y), (r, c)) pair.
        TupleToPairMapper tupleToPairMapper = new TupleToPairMapper();
        JavaPairRDD<Tuple2<Long, Long>, Tuple2<Integer, Integer>> recordsByTuple = reducedTupleRecords.mapToPair(tupleToPairMapper);
        logger.info("Size of records reduced by tuple: " + recordsByTuple.count());

        logger.info("Shutting down context...");
        ctx.stop();
    }

    private static JavaRDD<String> createRawRDD(JavaSparkContext ctx, File file) {
        return ctx.textFile(file.getAbsolutePath());
    }

    // Splits each Tuple4 into a ((x, y), (r, c)) key/value pair.
    private static class TupleToPairMapper implements PairFunction<Tuple4<Long, Long, Integer, Integer>, Tuple2<Long, Long>, Tuple2<Integer, Integer>> {

        public TupleToPairMapper() {
        }

        @Override
        public Tuple2<Tuple2<Long, Long>, Tuple2<Integer, Integer>> call(Tuple4<Long, Long, Integer, Integer> t) throws Exception {
            return new Tuple2<Tuple2<Long, Long>, Tuple2<Integer, Integer>>(
                    new Tuple2<Long, Long>(t._1(), t._2()),
                    new Tuple2<Integer, Integer>(t._3(), t._4()));
        }
    }

    // Sums two (r, c) value pairs element-wise. Declared but never used in train().
    private static class KVReducer implements Function2<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {

        public KVReducer() {
        }

        @Override
        public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> t1, Tuple2<Integer, Integer> t2) throws Exception {
            return new Tuple2<Integer, Integer>(t1._1() + t2._1(), t1._2() + t2._2());
        }
    }

    // Parses one tab-separated line: the first field holds '#'-delimited key
    // parts (only the first two are kept), the second field holds
    // '#'-delimited counter values.
    private static class LineSplitter implements Function<String, Tuple4<Long, Long, Integer, Integer>> {
        @Override
        public Tuple4<Long, Long, Integer, Integer> call(String line) throws Exception {
            String[] dataPoint = line.trim().split("\\t");
            String[] fArr = dataPoint[0].split("#");
            Long x = Long.valueOf(fArr[0].trim());
            Long y = Long.valueOf(fArr[1].trim());
            String[] lArr = dataPoint[1].split("#");
            int r = Integer.valueOf(lArr[1].trim());
            int c = Integer.valueOf(lArr[2].trim());
            return new Tuple4<Long, Long, Integer, Integer>(x, y, r, c);
        }
    }
}
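Note that KVReducer is declared but never wired into train(). A minimal sketch of how it could be used, summing the (r, c) counters per (x, y) key; this is an assumption about the intended use, not something the code above actually does:

        // Hypothetical: aggregate the counter pairs per key with KVReducer.
        JavaPairRDD<Tuple2<Long, Long>, Tuple2<Integer, Integer>> summedByKey =
                recordsByTuple.reduceByKey(new KVReducer());
        logger.info("Distinct keys after reduction: " + summedByKey.count());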
The data text file contains a single line, whose two fields are tab-separated (the gap below is a tab character):
1028#29922#89575#kuv6kvghgk1337d86d2111774 1#1#0 | |
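For reference, here is how LineSplitter takes that record apart, as a standalone sketch of the same parsing logic:

String line = "1028#29922#89575#kuv6kvghgk1337d86d2111774\t1#1#0";
String[] dataPoint = line.trim().split("\\t"); // ["1028#29922#89575#kuv6...", "1#1#0"]
String[] fArr = dataPoint[0].split("#");       // ["1028", "29922", "89575", "kuv6kvghgk1337d86d2111774"]
String[] lArr = dataPoint[1].split("#");       // ["1", "1", "0"]
// x = 1028L, y = 29922L, r = lArr[1] = 1, c = lArr[2] = 0
// => Tuple4(1028, 29922, 1, 0)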