Skip to content

Instantly share code, notes, and snippets.

@souravsingh
Created February 2, 2017 17:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save souravsingh/f78e795360154053fa85d55b279f1965 to your computer and use it in GitHub Desktop.
Save souravsingh/f78e795360154053fa85d55b279f1965 to your computer and use it in GitHub Desktop.
Decision Tree in Spark
import urllib
from time import time

import numpy as np
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
def create_labeled_point(line_split):
    """Convert one comma-split data record into a ``LabeledPoint``.

    The first 41 fields become the feature vector. Fields 1, 2 and 3 hold
    strings and are replaced by their position in the module-level lists
    ``protocols``, ``services`` and ``flags`` respectively; those lists must
    already be defined when this function is called.

    Args:
        line_split: Sequence of string fields for one record. Field 41 is
            the label and must be present.

    Returns:
        A ``LabeledPoint`` whose label is ``0.0`` when field 41 equals
        ``'normal.'`` and ``1.0`` otherwise, and whose features are the 41
        leading fields converted to floats.

    Raises:
        IndexError: If the record has fewer than 42 fields.
        ValueError: If a non-categorical field cannot be parsed as a float.
    """
    # Slicing copies, so the caller's record is not modified below.
    clean_line_split = line_split[0:41]

    # Swap each categorical string for its integer index in the lookup list.
    for column, categories in ((1, protocols), (2, services), (3, flags)):
        try:
            clean_line_split[column] = categories.index(clean_line_split[column])
        except ValueError:
            # Value not present in the lookup list: fall back to the
            # one-past-the-end index so the record is still usable.
            clean_line_split[column] = len(categories)

    # Binary target: anything that is not tagged 'normal.' counts as an attack.
    attack = 0.0 if line_split[41] == 'normal.' else 1.0

    return LabeledPoint(attack, np.array([float(x) for x in clean_line_split]))
# --- Download the KDD Cup 99 training and test archives ------------------
# urlretrieve moved to urllib.request in Python 3; resolve it here so the
# script runs on either interpreter.
try:
    from urllib.request import urlretrieve  # Python 3
except ImportError:
    from urllib import urlretrieve  # Python 2

f1 = urlretrieve("http://kdd.ics.uci.edu/databases/kddcup99/kddcup.data.gz", "kddcup.data.gz")
f2 = urlretrieve("http://kdd.ics.uci.edu/databases/kddcup99/corrected.gz", "corrected.gz")

# --- Load both files as RDDs of text lines --------------------------------
# NOTE(review): `sc` is not defined in this file; presumably it is the
# SparkContext supplied by the pyspark shell / notebook -- confirm.
data_file_kdd = "./kddcup.data.gz"
raw_data = sc.textFile(data_file_kdd)
test_data_file_kdd = "./corrected.gz"
test_raw_data = sc.textFile(test_data_file_kdd)

# Split every CSV line into its list of string fields.
csv_data = raw_data.map(lambda x: x.split(","))
test_csv_data = test_raw_data.map(lambda x: x.split(","))

# Distinct values of the three categorical columns, taken from the training
# set only. create_labeled_point reads these module-level lists.
protocols = csv_data.map(lambda x: x[1]).distinct().collect()
services = csv_data.map(lambda x: x[2]).distinct().collect()
flags = csv_data.map(lambda x: x[3]).distinct().collect()

# --- Build labeled points and train the classifier ------------------------
training_data = csv_data.map(create_labeled_point)
test_data = test_csv_data.map(create_labeled_point)

# NOTE(review): create_labeled_point encodes values missing from the lists
# above as len(list), which equals the arity declared here -- confirm that
# this is acceptable for test records with unseen categories.
tree_model = DecisionTree.trainClassifier(
    training_data,
    numClasses=2,
    categoricalFeaturesInfo={1: len(protocols), 2: len(services), 3: len(flags)},
    impurity='gini',
    maxDepth=4,
    maxBins=100,
)

# --- Evaluate on the test set ---------------------------------------------
prediction = tree_model.predict(test_data.map(lambda p: p.features))
labels_and_predis = test_data.map(lambda p: p.label).zip(prediction)

# Tuple-parameter lambdas (`lambda (v, p): ...`) are a syntax error on
# Python 3, so index into the (label, prediction) pair instead.
correct = labels_and_predis.filter(lambda pair: pair[0] == pair[1]).count()
test_accuracy = correct / float(test_data.count())

# Single-argument print(...) behaves the same on Python 2 and Python 3.
print("Accuracy on test data is: {}".format(round(test_accuracy, 4)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment