Last active
August 29, 2015 14:16
-
-
Save audiolion/9384d8855db4ab25b728 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
###
# Attribute Branching Impurity Calculator
# Zachary Friss
# CSCI 420
# Written for Python 3.4
###
import csv
from math import *
from statistics import *
def _read_csv(file_name):
    """Return ``(headers, rows)`` parsed from the CSV file *file_name*.

    The first non-blank row is the header row; every later non-blank row is
    a data row.  Each field is stripped of surrounding whitespace so a stray
    space or carriage return does not become a distinct attribute value, and
    blank lines (such as a trailing newline) are skipped rather than turned
    into empty rows.
    """
    # newline="" is what the csv module expects; it copes with \n and \r\n.
    with open(file_name, newline="") as csv_file:
        rows = [[field.strip() for field in row] for row in csv.reader(csv_file)]
    rows = [row for row in rows if any(row)]
    if not rows:
        return [], []
    return rows[0], rows[1:]


def _count_targets(data, index, target_value_index, target_yes_value):
    """Tally positive/negative target outcomes per distinct value of a column.

    Returns a dict mapping each value found in column *index* to a two-item
    list ``[yes_count, no_count]``.  A row counts as "yes" only when its
    target field equals *target_yes_value* exactly; anything else is "no".
    """
    counts = {}
    for entry in data:
        tally = counts.setdefault(entry[index], [0, 0])
        if entry[target_value_index] == target_yes_value:
            tally[0] += 1
        else:
            tally[1] += 1
    return counts


def _node_impurities(yes_count, no_count):
    """Return ``(gini, entropy, misclassification_error)`` for one node.

    The node must contain at least one row.  Entropy is measured in bits.  A
    class with probability 0 contributes 0 (the limit of p*log2(p) as p -> 0),
    which avoids the log2(0) math domain error on pure nodes.
    """
    total = yes_count + no_count
    prob_yes = yes_count / total
    prob_no = no_count / total
    gini = 1 - (prob_yes ** 2 + prob_no ** 2)
    entropy = sum(-p * log2(p) for p in (prob_yes, prob_no) if p > 0)
    misclassification_error = 1 - max(prob_yes, prob_no)
    return gini, entropy, misclassification_error


def main(file_name="Storm_Weather_Data_v112.csv", target_value_index=5,
         target_yes_value="yes"):
    """Print weighted impurity measures for branching on each attribute.

    For every column of the CSV except the target column, the data rows are
    split into nodes by that column's distinct values.  For each node the
    GINI index, entropy and misclassification error are computed, then
    combined weighted by the node's share of all rows.  Nodes are listed
    from highest to lowest entropy (ties ordered by value).

    :param file_name: path of the CSV file; its first row holds the headers.
    :param target_value_index: zero-based column index of the target value.
    :param target_yes_value: target value counted as the positive class;
        every other value counts as negative.
    :raises ValueError: if the file has no data rows, or the target index
        is outside the header columns.
    """
    headers, data = _read_csv(file_name)
    if not data:
        raise ValueError("no data rows found in " + file_name)
    if not 0 <= target_value_index < len(headers):
        raise ValueError(
            "target column index " + str(target_value_index)
            + " is outside the " + str(len(headers)) + " header columns")

    # Each node's weight is its share of all data rows.
    total_entries = len(data)

    for index, attribute_name in enumerate(headers):
        # Branching on the target column itself is trivially pure; skip it.
        if index == target_value_index:
            continue

        counts = _count_targets(data, index, target_value_index,
                                target_yes_value)
        mixed_gini = 0
        mixed_entropy = 0
        mixed_misclassification = 0
        nodes = []  # (value, entropy) pairs, one per distinct value
        for value, (yes_count, no_count) in counts.items():
            weight = (yes_count + no_count) / total_entries
            gini, entropy, misclassification = _node_impurities(yes_count,
                                                                no_count)
            mixed_gini += weight * gini
            mixed_entropy += weight * entropy
            mixed_misclassification += weight * misclassification
            nodes.append((value, entropy))

        # Highest entropy first; ties broken by value so output is stable even
        # where dict iteration order is arbitrary (Python < 3.7).
        nodes.sort(key=lambda node: (-node[1], node[0]))

        print("Branching on " + attribute_name)
        print("Weighted GINI index: " + str(round(mixed_gini, 3)))
        print("Weighted Entropy: " + str(round(mixed_entropy, 3)))
        print("Weighted Misclassification error: "
              + str(round(mixed_misclassification, 3)))
        print("Individual Node Values:")
        last_position = len(nodes) - 1
        for position, (value, entropy) in enumerate(nodes):
            if position == 0:
                rank = "Highest"
            elif position == last_position:
                rank = "Lowest"
            else:
                rank = "Middle"
            print(rank + " Entropy: (" + value + ") " + str(entropy))
        print()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment