Skip to content

Instantly share code, notes, and snippets.

@dapurv5
Created May 18, 2013 02:29
Show Gist options
  • Save dapurv5/5603016 to your computer and use it in GitHub Desktop.
Save dapurv5/5603016 to your computer and use it in GitHub Desktop.
Sparse Matrix Multiplication in Map Reduce
/**
* Copyright (c) 2013 Apurv Verma
*/
package org.anahata.hadoop.illustrations;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.anahata.commons.hadoop.io.IntInt;
import org.anahata.commons.hadoop.io.IntIntInt;
import org.anahata.commons.io.IOBoilerplate;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
public class SpMatrixMult {
private final static Log LOG = LogFactory.getLog(SpMatrixMult.class);
private final static String INPUT_PATH = "/tmp/matrix-mult/input.seq";
private final static String OUTPUT_PATH = "/tmp/matrix-mult/output";
private final static String L = "org.anahata.hadoop.illustrations.SpMatrixMult.L";
private final static String M = "org.anahata.hadoop.illustrations.SpMatrixMult.M";
private final static String N = "org.anahata.hadoop.illustrations.SpMatrixMult.N";
private final static Configuration conf = new Configuration();
public static class SpMatrixMultMapper extends Mapper<IntIntInt, IntWritable, IntIntInt, IntInt>{
private int n;
private int l;
private IntIntInt keyOut;
private IntInt valOut;
@Override
protected void setup(Context context) throws IOException,
InterruptedException {
super.setup(context);
Configuration conf = context.getConfiguration();
n = conf.getInt(N, 0);
l = conf.getInt(L, 0);
keyOut = new IntIntInt();
valOut = new IntInt();
}
@Override
protected void map(IntIntInt key, IntWritable value, Context context)
throws IOException, InterruptedException {
int matrixId = key.getFirst();
int i = key.getSecond();
int j = key.getThird();
if(matrixId == 0){
for(int n_ = 0; n_ < n; n_++){
keyOut.set(2, i, n_);
valOut.set(j, value.get());
context.write(keyOut, valOut);
}
} else{
for(int l_ = 0; l_ < l; l_++){
keyOut.set(2, l_, j);
valOut.set(i, value.get());
context.write(keyOut, valOut);
}
}
}
}
public static class SpMatrixMultReducer extends Reducer<IntIntInt, IntInt, IntIntInt, IntWritable>{
private IntIntInt key;
private IntWritable val;
private Map<Integer, Integer> aggrMap;
@Override
protected void setup(Context context)
throws IOException, InterruptedException {
super.setup(context);
key = new IntIntInt();
val = new IntWritable();
aggrMap = new HashMap<>();
}
@Override
protected void reduce(IntIntInt key, Iterable<IntInt> values, Context context) throws IOException,
InterruptedException {
for(IntInt ii: values){
int a = 1;
if(aggrMap.get(ii.getFirst()) != null){
a = aggrMap.get(ii.getFirst());
}
aggrMap.put(ii.getFirst(), a * ii.getSecond());
}
int res = 0;
for(Integer index: aggrMap.keySet()){
res += aggrMap.get(index);
}
val.set(res);
context.write(key, val);
aggrMap.clear();
}
}
public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
int[][] A = {
{0, 1, 2},
{2, 1, 3}
};
int[][] B = {
{1, 2},
{2, 0},
{1, 4}
};
assert(A[0].length == B.length);
//A (L cross M), B(M cross N)
writeInput(A, B);
conf.setInt(L, A.length);
conf.setInt(M, A[0].length);
conf.setInt(N, B[0].length);
Job job = new Job(conf);
job.setJarByClass(SpMatrixMult.class);
job.setJobName("Matrix Multiplication");
job.setMapperClass(SpMatrixMultMapper.class);
job.setReducerClass(SpMatrixMultReducer.class);
SequenceFileInputFormat.addInputPaths(job, INPUT_PATH);
SequenceFileOutputFormat.setOutputPath(job, new Path(OUTPUT_PATH));
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
job.setMapOutputKeyClass(IntIntInt.class);
job.setMapOutputValueClass(IntInt.class);
job.setOutputKeyClass(IntIntInt.class);
job.setOutputValueClass(IntWritable.class);
job.waitForCompletion(true);
writeOutput();
cleanUp();
}
private static void writeInput(int[][] A, int[][] B){
IntIntInt key = new IntIntInt();
IntWritable val = new IntWritable();
SequenceFile.Writer writer = null;
try {
FileSystem fs = FileSystem.get(conf);
writer = SequenceFile.createWriter(fs, conf, new Path(INPUT_PATH),
IntIntInt.class, IntWritable.class);
for(int i = 0; i < A.length; i++){
for(int j = 0; j < A[0].length; j++){
key.set(0, i, j); //0 specifies first matrix
val.set(A[i][j]);
writer.append(key, val);
}
}
for(int i = 0; i < B.length; i++){
for(int j = 0; j < B[0].length; j++){
key.set(1, i, j); //1 specifies first matrix
val.set(B[i][j]);
writer.append(key, val);
}
}
} catch (IOException e) {
LOG.error("Could not write the input", e);
} finally{
IOBoilerplate.closeGracefully(writer);
}
}
private static void writeOutput(){
IntIntInt key = new IntIntInt();
IntWritable val = new IntWritable();
SequenceFile.Reader reader = null;
try {
FileSystem fs = FileSystem.get(conf);
FileStatus[] status = fs.listStatus(new Path(OUTPUT_PATH));
for(FileStatus stat : status){
if(stat.isDir() || stat.getLen() == 0){continue;}
reader = new SequenceFile.Reader(fs, stat.getPath(), conf);
while(reader.next(key, val)){
System.out.println(key+" __ "+val);
}
reader.close();
}
} catch (IOException e) {
LOG.error("Could not fetch result", e);
} finally{
IOBoilerplate.closeGracefully(reader);
}
}
private static void cleanUp(){
try {
FileSystem fs = FileSystem.get(conf);
fs.delete(new Path(OUTPUT_PATH), true);
} catch (IOException e) {
LOG.error("Could not cleanup files", e);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment