@jossef
Last active January 3, 2019 03:44
package jossef.cli.mahout;

import au.com.bytecode.opencsv.CSVReader;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import java.io.FileReader;
import java.io.IOException;
import java.text.NumberFormat;
import java.text.ParsePosition;
import java.util.List;
import java.util.Map;

import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class CsvToVectors
{
    // Dictionary that maps each distinct string value to a numeric id
    private static long wordCount = 1;
    private static final Map<String, Long> words = Maps.newHashMap();

    private final String csvPath;

    public CsvToVectors(String csvPath)
    {
        this.csvPath = csvPath;
    }

    public List<MahoutVector> vectorize() throws IOException
    {
        List<MahoutVector> vectors = Lists.newArrayList();

        // Iterate the CSV records
        CSVReader reader = new CSVReader(new FileReader(this.csvPath));
        String[] line;
        try
        {
            while ((line = reader.readNext()) != null)
            {
                // 41 feature columns; the last column holds the class label
                Vector vector = new RandomAccessSparseVector(line.length - 1, line.length - 1);
                int rowIndex = 0;
                // @attribute 'duration' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'protocol_type' {'tcp','udp', 'icmp'}
                vector.set(rowIndex, processString(line[rowIndex]));
                rowIndex++;
                // @attribute 'service' {'aol', 'auth', 'bgp', 'courier', 'csnet_ns', 'ctf', 'daytime', 'discard', 'domain', 'domain_u', 'echo', 'eco_i', 'ecr_i', 'efs', 'exec', 'finger', 'ftp', 'ftp_data', 'gopher', 'harvest', 'hostnames', 'http', 'http_2784', 'http_443', 'http_8001', 'imap4', 'IRC', 'iso_tsap', 'klogin', 'kshell', 'ldap', 'link', 'login', 'mtp', 'name', 'netbios_dgm', 'netbios_ns', 'netbios_ssn', 'netstat', 'nnsp', 'nntp', 'ntp_u', 'other', 'pm_dump', 'pop_2', 'pop_3', 'printer', 'private', 'red_i', 'remote_job', 'rje', 'shell', 'smtp', 'sql_net', 'ssh', 'sunrpc', 'supdup', 'systat', 'telnet', 'tftp_u', 'tim_i', 'time', 'urh_i', 'urp_i', 'uucp', 'uucp_path', 'vmnet', 'whois', 'X11', 'Z39_50'}
                vector.set(rowIndex, processString(line[rowIndex]));
                rowIndex++;
                // @attribute 'flag' { 'OTH', 'REJ', 'RSTO', 'RSTOS0', 'RSTR', 'S0', 'S1', 'S2', 'S3', 'SF', 'SH' }
                vector.set(rowIndex, processString(line[rowIndex]));
                rowIndex++;
                // @attribute 'src_bytes' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_bytes' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'land' {'0', '1'}
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'wrong_fragment' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'urgent' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'hot' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'num_failed_logins' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'logged_in' {'0', '1'}
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'num_compromised' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'root_shell' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'su_attempted' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'num_root' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'num_file_creations' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'num_shells' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'num_access_files' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'num_outbound_cmds' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'is_host_login' {'0', '1'}
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'is_guest_login' {'0', '1'}
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'count' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'srv_count' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'serror_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'srv_serror_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'rerror_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'srv_rerror_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'same_srv_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'diff_srv_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'srv_diff_host_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_host_count' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_host_srv_count' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_host_same_srv_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_host_diff_srv_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_host_same_src_port_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_host_srv_diff_host_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_host_serror_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_host_srv_serror_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_host_rerror_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'dst_host_srv_rerror_rate' real
                vector.set(rowIndex, processNumeric(line[rowIndex]));
                rowIndex++;
                // @attribute 'class' {'normal', 'anomaly'}
                rowIndex = 41; // the class label is always the last column
                String classifier = line[rowIndex];

                MahoutVector mahoutVector = new MahoutVector();
                mahoutVector.classifier = classifier;
                mahoutVector.vector = vector;
                vectors.add(mahoutVector);
            }
            return vectors;
        }
        finally
        {
            reader.close();
        }
    }
    // Not sure how scalable this is going to be
    protected double processString(String data)
    {
        // Assign each distinct string a stable numeric id (dictionary encoding)
        Long theLong = words.get(data);
        if (theLong == null)
        {
            theLong = wordCount++;
            words.put(data, theLong);
        }
        return theLong;
    }

    protected double processNumeric(String data)
    {
        // Non-numeric cells become NaN rather than failing the whole row
        double d = Double.NaN;
        if (isNumeric(data))
        {
            d = Double.parseDouble(data);
        }
        return d;
    }

    public static boolean isNumeric(String str)
    {
        // Numeric only if the formatter consumed the entire string
        NumberFormat formatter = NumberFormat.getInstance();
        ParsePosition parsePosition = new ParsePosition(0);
        formatter.parse(str, parsePosition);
        return str.length() == parsePosition.getIndex();
    }
}
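
For reference, a minimal driver for this class might look like the sketch below. The class name and CSV path are placeholders of mine, not part of the gist; point the path at any 42-column KDD-style CSV.

package jossef.cli.mahout;

import java.io.IOException;
import java.util.List;

// Hypothetical example, not part of the original gist
public class CsvToVectorsExample
{
    public static void main(String[] args) throws IOException
    {
        // Placeholder path: substitute your own copy of the dataset
        CsvToVectors csvToVectors = new CsvToVectors("/path/to/KDDTrain+.csv");
        List<MahoutVector> vectors = csvToVectors.vectorize();
        System.out.println("Vectorized " + vectors.size() + " records");
    }
}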
KDDTest+.csv (excerpt; each row has 42 comma-separated columns: 41 features followed by the class label):
0,tcp,private,REJ,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,229,10,0.00,0.00,1.00,1.00,0.04,0.06,0.00,255,10,0.04,0.06,0.00,0.00,0.00,0.00,1.00,1.00,anomaly
0,tcp,private,REJ,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,136,1,0.00,0.00,1.00,1.00,0.01,0.06,0.00,255,1,0.00,0.06,0.00,0.00,0.00,0.00,1.00,1.00,anomaly
2,tcp,ftp_data,SF,12983,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0.00,0.00,0.00,0.00,1.00,0.00,0.00,134,86,0.61,0.04,0.61,0.02,0.00,0.00,0.00,0.00,normal
...
...
...
KDDTrain+.csv (excerpt):
0,tcp,ftp_data,SF,491,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0.00,0.00,0.00,0.00,1.00,0.00,0.00,150,25,0.17,0.03,0.17,0.00,0.00,0.00,0.05,0.00,normal
0,udp,other,SF,146,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,13,1,0.00,0.00,0.00,0.00,0.08,0.15,0.00,255,1,0.00,0.60,0.88,0.00,0.00,0.00,0.00,0.00,normal
0,tcp,private,S0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,123,6,1.00,1.00,0.00,0.00,0.05,0.07,0.00,255,26,0.10,0.05,0.00,0.00,1.00,1.00,0.00,0.00,anomaly
...
...
...
package jossef.cli.mahout;

import java.io.IOException;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class MahoutTest
{
    private static final Logger logger = LoggerFactory.getLogger(MahoutTest.class);

    public static void csv2vectors() throws IOException
    {
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.getLocal(conf);

        Path seqFilePath = new Path("/root/workspace/jossef.cli/kdd/kdd-seq");
        fs.delete(seqFilePath, false);
        SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, seqFilePath, Text.class, VectorWritable.class);

        String csvPath = "/home/jossef/Desktop/kdd/KDDTrain+.csv";
        try
        {
            CsvToVectors csvToVectors = new CsvToVectors(csvPath);
            List<MahoutVector> vectors = csvToVectors.vectorize();

            // Write each vector keyed by its label ("/<label>/") so the
            // trainer can extract the labels from the sequence file keys
            for (MahoutVector vector : vectors)
            {
                VectorWritable vectorWritable = new VectorWritable();
                vectorWritable.set(vector.vector);
                writer.append(new Text("/" + vector.classifier + "/"), vectorWritable);
            }
        }
        finally
        {
            writer.close();
        }
    }
    public static void train() throws Throwable
    {
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.getLocal(conf);

        TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
        trainNaiveBayes.setConf(conf);

        String sequenceFile = "/root/workspace/jossef.cli/kdd/kdd-seq";
        String outputDirectory = "/root/workspace/jossef.cli/output";
        String tempDirectory = "/root/workspace/jossef.cli/temp";
        fs.delete(new Path(outputDirectory), true);
        fs.delete(new Path(tempDirectory), true);

        // Train the classifier ("-el" tells the job to extract the labels from the input keys)
        trainNaiveBayes.run(new String[] { "--input", sequenceFile, "--output", outputDirectory, "-el", "--overwrite", "--tempDir", tempDirectory });

        NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDirectory), conf);
        System.out.println("features: " + naiveBayesModel.numFeatures());
        System.out.println("labels: " + naiveBayesModel.numLabels());

        AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel);
        //AbstractVectorClassifier classifier = new StandardNaiveBayesClassifier(naiveBayesModel);

        String csvPath = "/home/jossef/Desktop/kdd/KDDTest+.csv";
        CsvToVectors csvToVectors = new CsvToVectors(csvPath);
        List<MahoutVector> vectors = csvToVectors.vectorize();

        int total = 0;
        int success = 0;
        for (MahoutVector mahoutVector : vectors)
        {
            Vector prediction = classifier.classifyFull(mahoutVector.vector);

            // Labels are sorted alphabetically:
            // 0 = anomaly, 1 = normal (because 'anomaly' < 'normal')
            double anomaly = prediction.get(0);
            double normal = prediction.get(1);

            String predictedClass = "anomaly";
            if (normal > anomaly)
            {
                predictedClass = "normal";
            }

            if (predictedClass.equals(mahoutVector.classifier))
            {
                success++;
            }
            total++;
        }

        System.out.println(total + " : " + success + " : " + (total - success) + " " + ((double) success / total));
    }

    public static void main(String[] args) throws Throwable
    {
        csv2vectors();
        train();
    }
}
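
The commented-out line in train() hints at trying the standard (non-complementary) naive Bayes classifier as well. Below is a minimal sketch of that swap; the helper is hypothetical (not part of the gist) and would sit alongside train() in this class, reusing the same materialized model.

    // Hypothetical helper, not part of the original gist: classify a single
    // record with the standard naive Bayes classifier and return its label.
    public static String classifyStandard(NaiveBayesModel model, MahoutVector record)
    {
        AbstractVectorClassifier standard = new StandardNaiveBayesClassifier(model);
        Vector scores = standard.classifyFull(record.vector);
        // Label ids follow the same alphabetical order as in train(): 0 = anomaly, 1 = normal
        return scores.get(1) > scores.get(0) ? "normal" : "anomaly";
    }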
package jossef.cli.mahout;

import org.apache.mahout.math.Vector;

// Simple holder pairing a feature vector with its class label
public class MahoutVector
{
    public String classifier;
    public Vector vector;
}
@RangerWolf

Hi Jossef,
I found your gist via the Apache mailing list and ran your sample with Mahout 0.9 + Hadoop 1.2.1.
I also used the KDD Cup 1999 data set, both the full set and the 10% subset, as training samples.
But I could only get 62-65% accuracy.

Could you share your results? And do you have any suggestions for improving it?

Thanks!
Sorry for my bad English!

@amadeimatti

Hi Jossef! I'm new to Mahout. I have a .csv file with numerical data and a class column. I'm trying to use your class from GitHub (with Hadoop 1.2.x and Mahout 0.9), but it gives me the error "File temp/labelIndex does not exist". Do you know why? Thank you very much for your work!
Matteo
