Skip to content

Instantly share code, notes, and snippets.

@audiolion
Last active August 29, 2015 14:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save audiolion/9384d8855db4ab25b728 to your computer and use it in GitHub Desktop.
Save audiolion/9384d8855db4ab25b728 to your computer and use it in GitHub Desktop.
###
# Attribute Branching Impurity Calculator
# Zachary Friss
# CSCI 420
# Written for Python 3.4
###
from statistics import *
from math import *
def _read_csv(fileName):
    """Return (headers, rows) parsed from the CSV file at *fileName*.

    The first non-blank row is the column headings; the remaining non-blank
    rows are the data entries. Every field is stripped of surrounding
    whitespace, and completely blank rows (e.g. the trailing newline at the
    end of the file) are skipped so they cannot cause IndexErrors later.

    Raises:
        ValueError: if a data row has a different number of fields than the
            header row.
    """
    import csv

    # The with-block guarantees the file is closed even if parsing fails.
    with open(fileName, newline="") as f:
        rows = [
            [field.strip() for field in row]
            for row in csv.reader(f)
            if any(field.strip() for field in row)
        ]
    if not rows:
        return [], []
    headers, entries = rows[0], rows[1:]
    for entry in entries:
        if len(entry) != len(headers):
            raise ValueError(
                "expected {} fields but found {}: {!r}".format(
                    len(headers), len(entry), entry))
    return headers, entries


def _count_targets(data, index, targetValueIndex, targetYesValue):
    """Group *data* by the value in column *index* and tally target outcomes.

    Returns a dict mapping each distinct value to a node dict with keys
    'value', 'targetYes' and 'targetNo'.
    """
    nodes = {}
    for entry in data:
        value = str(entry[index])
        node = nodes.setdefault(
            value, {"value": value, "targetYes": 0, "targetNo": 0})
        # Increment the correct outcome counter for this data entry.
        if entry[targetValueIndex] == targetYesValue:
            node["targetYes"] += 1
        else:
            node["targetNo"] += 1
    return nodes


def _node_impurities(targetYes, targetNo):
    """Return (gini, entropy, misclassification_error) for one node's counts."""
    total = targetYes + targetNo
    probabilities = (targetYes / total, targetNo / total)
    gini = 1 - sum(p * p for p in probabilities)
    # By convention 0 * log2(0) == 0; calling log2(0) itself raises
    # ValueError, which crashed the calculation for every pure node.
    entropy = 0.0
    for p in probabilities:
        if p > 0:
            entropy -= p * log2(p)
    misClassificationError = 1 - max(probabilities)
    return gini, entropy, misClassificationError


def _print_branch(heading, nodes, totalEntries):
    """Print the weighted impurities of splitting on one attribute.

    Each node is weighted by its share of *totalEntries*. After the weighted
    totals, every node's entropy is listed from highest to lowest.
    """
    mixedGiniIndex = 0
    mixedEntropy = 0
    mixedMisClassificationError = 0
    nodesArray = []
    for node in nodes.values():
        weight = (node["targetYes"] + node["targetNo"]) / totalEntries
        gini, entropy, misClassificationError = _node_impurities(
            node["targetYes"], node["targetNo"])
        mixedGiniIndex += weight * gini
        mixedEntropy += weight * entropy
        mixedMisClassificationError += weight * misClassificationError
        node["entropy"] = entropy
        nodesArray.append(node)
    nodesArray.sort(key=lambda n: n["entropy"], reverse=True)

    print("Branching on " + str(heading))
    print("Weighted GINI index: " + str(round(mixedGiniIndex, 3)))
    print("Weighted Entropy: " + str(round(mixedEntropy, 3)))
    print("Weighted Misclassification error: "
          + str(round(mixedMisClassificationError, 3)))
    print("Individual Node Values:")
    lastIndex = len(nodesArray) - 1
    for index, node in enumerate(nodesArray):
        if index == 0:
            label = "Highest Entropy"
        elif index == lastIndex:
            label = "Lowest Entropy"
        else:
            label = "Middle Entropy"
        print(label + ": (" + str(node["value"]) + ") " + str(node["entropy"]))
    print()


def main(fileName="Storm_Weather_Data_v112.csv", targetValueIndex=5,
         targetYesValue="yes"):
    """Print impurity measures for branching on each attribute of a CSV file.

    For every column other than the target column, rows are grouped by that
    column's value (one node per distinct value). The function prints the
    split's weighted GINI index, weighted entropy and weighted
    misclassification error, followed by each node's entropy. An empty file
    produces no output.

    Args:
        fileName: path of the CSV file; its first non-blank row holds the
            column headings.
        targetValueIndex: column index of the target value (negative indices
            count from the end, as with list indexing).
        targetYesValue: target value counted as the positive outcome; any
            other value counts as negative.

    Raises:
        ValueError: if targetValueIndex lies outside the header row, or a
            data row has a different number of fields than the header row.
    """
    headers, data = _read_csv(fileName)
    if not headers:
        return
    if not -len(headers) <= targetValueIndex < len(headers):
        raise ValueError(
            "targetValueIndex {} is out of range for {} columns".format(
                targetValueIndex, len(headers)))
    # Normalise a negative index so the target column can be recognised below.
    targetValueIndex %= len(headers)

    totalEntries = len(data)
    for index, heading in enumerate(headers):
        # Branching on the target itself is trivially pure and uninformative.
        if index == targetValueIndex:
            continue
        nodes = _count_targets(data, index, targetValueIndex, targetYesValue)
        _print_branch(heading, nodes, totalEntries)
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment