Skip to content

Instantly share code, notes, and snippets.

@vinaychittora
Last active August 29, 2015 13:59
Show Gist options
  • Save vinaychittora/10828300 to your computer and use it in GitHub Desktop.
Save vinaychittora/10828300 to your computer and use it in GitHub Desktop.
Decision Tree Using SciKit
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.externals.six import StringIO
import os
import pydot
import csv
import re
class LoadData():
def __init__(self):
self.readfile = csv.DictReader(open(os.getcwd()+"/segmented_data.csv", 'r'), delimiter=",", quotechar='"', dialect=csv.excel_tab)
self.threshold = 1000000
self.segments = {'HOUSE':1, 'BUILDING':2, 'LOCALITY':3, 'LANDMARK':4}
self.records = []
self.X = []
self.Y = []
self.count = {'0':0, '1':0, '2':0, '3':0, '4':0, '5':0, '6':0}
def load(self):
for i, row in enumerate(self.readfile):
parity = 0
for segment in self.segments.keys():
if len(row[segment])>=1:
parity += 1
if parity == len(self.segments.keys()):
self.records.append(row)
def feature_extraction(self):
for record in self.records :
for segment in self.segments.keys():
sample = record[segment]
features = [0 for i in range(7)]
# only alpha
if re.match("^[a-zA-Z.\-\ ]*$" , sample):
features[0] = 1
self.count['0'] += 1
# only numeric
if re.match('^[0-9\\\/\.\ \-]*$', sample) :
features[1] = 1
self.count['1'] += 1
# alpha numeric
if re.match('^\W+$', sample) :
features[2] = 1
self.count['2'] += 1
# contains house no
if sample.lower().find('house') != -1 or sample.lower().find('h.no.') !=-1:
features[3] = 1
self.count['3'] += 1
# contains building name
if sample.lower().find('tower') != -1 or sample.lower().find('apartment') !=-1 or sample.lower().find('building') !=-1 or sample.lower().find('hostal') !=-1 :
features[4] = 1
self.count['4'] += 1
# contains locality name
if sample.lower().find('road') != -1 or sample.lower().find('colony') !=-1 or sample.lower().find('marg') !=-1 or sample.lower().find('square') !=-1 or sample.lower().find('vihar') !=-1 or sample.lower().find('nagar') !=-1 or sample.lower().find('line') !=-1:
features[5] = 1
self.count['5'] += 1
# contains landmark name
if sample.lower().find('near') != -1 or sample.lower().find('behind') !=-1 or sample.lower().find('opp') !=-1 :
features[6] = 1
self.count['6'] += 1
self.Y.append(self.segments[segment])
self.X.append(features)
def decision_tree_generator(self):
self.load()
self.feature_extraction()
clf = tree.DecisionTreeClassifier()
print len(self.Y)
print self.count
clf = clf.fit(self.X, self.Y)
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("addresses.pdf")
if __name__ == '__main__':
    # Script entry point: build a loader with its defaults and run the whole
    # load -> extract -> fit -> export pipeline.
    pipeline = LoadData()
    pipeline.decision_tree_generator()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment