Last active
January 3, 2019 03:44
-
-
Save jossef/e6c8fc0c31f0c2bf036a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package jossef.cli.mahout; | |
import au.com.bytecode.opencsv.CSVReader; | |
import com.google.common.collect.Lists; | |
import com.google.common.collect.Maps; | |
import java.io.FileReader; | |
import java.io.IOException; | |
import java.text.DateFormat; | |
import java.text.NumberFormat; | |
import java.text.ParseException; | |
import java.text.ParsePosition; | |
import java.text.SimpleDateFormat; | |
import java.util.Collections; | |
import java.util.Date; | |
import java.util.HashMap; | |
import java.util.List; | |
import java.util.Locale; | |
import java.util.Map; | |
import java.util.regex.Pattern; | |
import org.apache.mahout.math.DenseVector; | |
import org.apache.mahout.math.RandomAccessSparseVector; | |
import org.apache.mahout.math.SequentialAccessSparseVector; | |
import org.apache.mahout.math.Vector; | |
public class CsvToVectors | |
{ | |
private static long wordCount = 1; | |
private static final Map<String, Long> words = Maps.newHashMap(); | |
private final String csvPath; | |
public CsvToVectors(String csvPath) | |
{ | |
this.csvPath = csvPath; | |
} | |
public List<MahoutVector> vectorize() throws IOException | |
{ | |
List<MahoutVector> vectors = Lists.newArrayList(); | |
// Iterate the CSV records | |
CSVReader reader = new CSVReader(new FileReader(this.csvPath)); | |
String[] line; | |
try | |
{ | |
while ((line = reader.readNext()) != null) | |
{ | |
Vector vector = new RandomAccessSparseVector(line.length-1, line.length-1); | |
int rowIndex = 0; | |
// @attribute 'duration' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'protocol_type' {'tcp','udp', 'icmp'} | |
vector.set(rowIndex, processString(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'service' {'aol', 'auth', 'bgp', 'courier', 'csnet_ns', 'ctf', 'daytime', 'discard', 'domain', 'domain_u', 'echo', 'eco_i', 'ecr_i', 'efs', 'exec', 'finger', 'ftp', 'ftp_data', 'gopher', 'harvest', 'hostnames', 'http', 'http_2784', 'http_443', 'http_8001', 'imap4', 'IRC', 'iso_tsap', 'klogin', 'kshell', 'ldap', 'link', 'login', 'mtp', 'name', 'netbios_dgm', 'netbios_ns', 'netbios_ssn', 'netstat', 'nnsp', 'nntp', 'ntp_u', 'other', 'pm_dump', 'pop_2', 'pop_3', 'printer', 'private', 'red_i', 'remote_job', 'rje', 'shell', 'smtp', 'sql_net', 'ssh', 'sunrpc', 'supdup', 'systat', 'telnet', 'tftp_u', 'tim_i', 'time', 'urh_i', 'urp_i', 'uucp', 'uucp_path', 'vmnet', 'whois', 'X11', 'Z39_50'} | |
vector.set(rowIndex, processString(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'flag' { 'OTH', 'REJ', 'RSTO', 'RSTOS0', 'RSTR', 'S0', 'S1', 'S2', 'S3', 'SF', 'SH' } | |
vector.set(rowIndex, processString(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'src_bytes' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_bytes' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'land' {'0', '1'} | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'wrong_fragment' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'urgent' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'hot' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'num_failed_logins' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'logged_in' {'0', '1'} | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'num_compromised' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'root_shell' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'su_attempted' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'num_root' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'num_file_creations' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'num_shells' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'num_access_files' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'num_outbound_cmds' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'is_host_login' {'0', '1'} | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'is_guest_login' {'0', '1'} | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'count' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'srv_count' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'serror_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'srv_serror_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'rerror_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'srv_rerror_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'same_srv_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'diff_srv_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'srv_diff_host_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_host_count' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_host_srv_count' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_host_same_srv_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_host_diff_srv_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_host_same_src_port_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_host_srv_diff_host_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_host_serror_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_host_srv_serror_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_host_rerror_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
// @attribute 'dst_host_srv_rerror_rate' real | |
vector.set(rowIndex, processNumeric(line[rowIndex])); | |
rowIndex++; | |
rowIndex=41; | |
// @attribute 'class' {'normal', 'anomaly'} | |
String classifier = line[rowIndex]; | |
MahoutVector mahoutVector = new MahoutVector(); | |
mahoutVector.classifier = classifier; | |
mahoutVector.vector = vector; | |
vectors.add(mahoutVector); | |
} | |
return vectors; | |
} | |
finally | |
{ | |
reader.close(); | |
} | |
} | |
// Not sure how scalable this is going to be | |
protected double processString(String data) | |
{ | |
Long theLong = words.get(data); | |
if (theLong == null) | |
{ | |
theLong = wordCount++; | |
words.put(data, theLong); | |
} | |
return theLong; | |
} | |
protected double processNumeric(String data) | |
{ | |
Double d = Double.NaN; | |
if (isNumeric(data)) | |
{ | |
d = Double.parseDouble(data); | |
} | |
return d; | |
} | |
public static boolean isNumeric(String str) | |
{ | |
NumberFormat formatter = NumberFormat.getInstance(); | |
ParsePosition parsePosition = new ParsePosition(0); | |
formatter.parse(str, parsePosition); | |
return str.length() == parsePosition.getIndex(); | |
} | |
} |
We can make this file beautiful and searchable if this error is corrected: It looks like row 4 should actually have 42 columns, instead of 1. in line 3.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
0,tcp,private,REJ,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,229,10,0.00,0.00,1.00,1.00,0.04,0.06,0.00,255,10,0.04,0.06,0.00,0.00,0.00,0.00,1.00,1.00,anomaly | |
0,tcp,private,REJ,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,136,1,0.00,0.00,1.00,1.00,0.01,0.06,0.00,255,1,0.00,0.06,0.00,0.00,0.00,0.00,1.00,1.00,anomaly | |
2,tcp,ftp_data,SF,12983,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0.00,0.00,0.00,0.00,1.00,0.00,0.00,134,86,0.61,0.04,0.61,0.02,0.00,0.00,0.00,0.00,normal | |
... | |
... | |
... |
We can make this file beautiful and searchable if this error is corrected: It looks like row 4 should actually have 42 columns, instead of 1. in line 3.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
0,tcp,ftp_data,SF,491,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0.00,0.00,0.00,0.00,1.00,0.00,0.00,150,25,0.17,0.03,0.17,0.00,0.00,0.00,0.05,0.00,normal | |
0,udp,other,SF,146,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,13,1,0.00,0.00,0.00,0.00,0.08,0.15,0.00,255,1,0.00,0.60,0.88,0.00,0.00,0.00,0.00,0.00,normal | |
0,tcp,private,S0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,123,6,1.00,1.00,0.00,0.00,0.05,0.07,0.00,255,26,0.10,0.05,0.00,0.00,1.00,1.00,0.00,0.00,anomaly | |
... | |
... | |
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package jossef.cli.mahout; | |
import java.io.BufferedReader; | |
import java.io.FileReader; | |
import java.io.IOException; | |
import java.util.List; | |
import org.apache.hadoop.conf.Configuration; | |
import org.apache.hadoop.fs.FileSystem; | |
import org.apache.hadoop.fs.Path; | |
import org.apache.hadoop.io.LongWritable; | |
import org.apache.hadoop.io.SequenceFile; | |
import org.apache.hadoop.io.Text; | |
import org.apache.mahout.math.RandomAccessSparseVector; | |
import org.apache.mahout.math.Vector; | |
import org.apache.mahout.math.VectorWritable; | |
import org.apache.hadoop.io.SequenceFile.Writer; | |
import org.apache.hadoop.fs.Path; | |
import org.apache.mahout.classifier.AbstractVectorClassifier; | |
import org.apache.mahout.classifier.ClassifierResult; | |
import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier; | |
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier; | |
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel; | |
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier; | |
import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob; | |
import org.apache.mahout.common.HadoopUtil; | |
import org.apache.mahout.math.Vector; | |
import org.apache.mahout.math.VectorWritable; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
public class MahoutTest | |
{ | |
private static final Logger logger = LoggerFactory.getLogger(MahoutTest.class); | |
public static void csv2vectors() throws IOException | |
{ | |
Configuration conf = new Configuration(); | |
FileSystem fs = FileSystem.getLocal(conf); | |
Path seqFilePath = new Path("/root/workspace/jossef.cli/kdd/kdd-seq"); | |
fs.delete(seqFilePath,false); | |
SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, seqFilePath, Text.class, VectorWritable.class); | |
String csvPath = "/home/jossef/Desktop/kdd/KDDTrain+.csv"; | |
try | |
{ | |
CsvToVectors csvToVectors = new CsvToVectors(csvPath); | |
List<MahoutVector> vectors = csvToVectors.vectorize(); | |
// Init the labels | |
for (MahoutVector vector : vectors) | |
{ | |
VectorWritable vectorWritable = new VectorWritable(); | |
vectorWritable.set(vector.vector); | |
writer.append(new Text("/" + vector.classifier + "/"), vectorWritable); | |
} | |
} | |
finally | |
{ | |
writer.close(); | |
} | |
} | |
public static void train() throws Throwable | |
{ | |
Configuration conf = new Configuration(); | |
FileSystem fs = FileSystem.getLocal(conf); | |
Path seqFilePath = new Path("/root/workspace/jossef.cli/kdd/kdd-seq"); | |
TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob(); | |
trainNaiveBayes.setConf(conf); | |
String sequenceFile = "/root/workspace/jossef.cli/kdd/kdd-seq"; | |
String outputDirectory = "/root/workspace/jossef.cli/output"; | |
String tempDirectory = "/root/workspace/jossef.cli/temp"; | |
fs.delete(new Path(outputDirectory),true); | |
fs.delete(new Path(tempDirectory),true); | |
trainNaiveBayes.run(new String[] { "--input", sequenceFile, "--output", outputDirectory, "-el", "--overwrite", "--tempDir", tempDirectory }); | |
// Train the classifier | |
NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDirectory), conf); | |
System.out.println("features: " + naiveBayesModel.numFeatures()); | |
System.out.println("labels: " + naiveBayesModel.numLabels()); | |
AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel); | |
String csvPath = "/home/jossef/Desktop/kdd/KDDTest+.csv"; | |
CsvToVectors csvToVectors = new CsvToVectors(csvPath); | |
List<MahoutVector> vectors = csvToVectors.vectorize(); | |
int total = 0; | |
int success = 0; | |
for (MahoutVector mahoutVector : vectors) | |
{ | |
Vector prediction = classifier.classifyFull(mahoutVector.vector); | |
// They sorted alphabetically | |
// 0 = anomaly, 1 = normal (because 'anomaly' > 'normal') | |
double anomaly = prediction.get(0); | |
double normal = prediction.get(1); | |
String predictedClass = "anomaly"; | |
if (normal > anomaly) | |
{ | |
predictedClass="normal"; | |
} | |
if (predictedClass.equals(mahoutVector.classifier)) | |
{ | |
success++; | |
} | |
total ++; | |
} | |
System.out.println(total + " : " + success + " : " + (total - success) + " " + ((double)success/total)); | |
//StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(); | |
} | |
public static void main(String[] args) throws Throwable | |
{ | |
csv2vectors(); | |
train(); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package jossef.cli.mahout;

import org.apache.mahout.math.Vector;

/**
 * Simple pairing of a feature vector with its class label, produced by
 * CsvToVectors.vectorize() and consumed when writing training sequence
 * files and when scoring predictions.
 *
 * <p>Plain public-field data holder — mutable and not thread-safe.
 */
public class MahoutVector
{
    // Class label for this record (here: "normal" or "anomaly").
    public String classifier;

    // Encoded feature vector for this record.
    public Vector vector;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi Jossef! I'm new to Mahout. I have a CSV file with numerical data and a class column. I'm trying to use your class from GitHub (with Hadoop 1.2.x and Mahout 0.9), but it gives me the error "File temp/labelIndex does not exist"... Do you know why? Thank you very much for your work!
Matteo