Last active
January 16, 2017 05:32
-
-
Save thmain/6c28adf2d1951b6d15963b7299fe3bc4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math; #For pow and sqrt | |
from random import shuffle; | |
def ReadData(fileName):
    """Load a comma-separated data set and return its rows in random order.

    The first line is a header naming every column. All columns except the
    last are numeric features (converted with ``float``). The last column is
    the class label, stored unconverted under the ``"Class"`` key.

    :param fileName: path of the file to read.
    :returns: list of dicts, one per non-blank data row. Each dict maps every
        feature name to its float value, plus ``"Class"`` -> label string.
        The list is shuffled with ``random.shuffle`` before it is returned.
        An empty file yields an empty list.
    :raises ValueError: if a row does not have exactly one value per header
        column, or if a feature value cannot be converted to float.
    """
    # `with` closes the file even if reading fails part-way.
    with open(fileName, 'r') as f:
        lines = f.read().splitlines()
    if not lines:
        return []
    # Split the header by commas and drop the LAST element (the class column);
    # what remains are the feature names.
    features = lines[0].split(',')[:-1]
    items = []
    for lineNumber, line in enumerate(lines[1:], start=2):
        # Blank lines (e.g. a trailing empty line) carry no data and would
        # otherwise fail the float conversion below.
        if not line.strip():
            continue
        values = line.split(',')
        if len(values) != len(features) + 1:
            raise ValueError(
                "line %d: expected %d comma-separated values, got %d"
                % (lineNumber, len(features) + 1, len(values)))
        itemFeatures = {"Class": values[-1]}
        for name, raw in zip(features, values):
            itemFeatures[name] = float(raw)
        items.append(itemFeatures)
    shuffle(items)
    return items
###_Auxiliary Function_### | |
def EuclideanDistance(x, y):
    """Return the Euclidean distance between two feature dicts.

    Only the keys of ``x`` are compared, so ``y`` may carry extra entries
    (such as a ``"Class"`` label) that are ignored. Every key of ``x`` must
    also be present in ``y``; a missing key raises KeyError.
    """
    squaredGaps = (math.pow(x[name] - y[name], 2) for name in x.keys())
    return math.sqrt(sum(squaredGaps))
def CalculateNeighborsClass(neighbors, k):
    """Count how often each class label occurs among the first k neighbors.

    :param neighbors: list of ``[distance, class]`` pairs; the class label is
        read from index 1 of each entry.
    :param k: number of leading entries to count. If ``neighbors`` holds
        fewer than ``k`` entries, all of them are counted. A ``k`` of zero or
        less counts nothing.
    :returns: dict mapping class label -> number of occurrences.
    """
    count = {}
    if k <= 0:
        # Guard explicitly: a negative slice bound would count from the end.
        return count
    # Slicing (instead of indexing 0..k-1) tolerates a list shorter than k,
    # e.g. when the training set has fewer than k items.
    for entry in neighbors[:k]:
        label = entry[1]
        count[label] = count.get(label, 0) + 1
    return count
def FindMax(Dict):
    """Return the ``(key, value)`` pair with the largest value in ``Dict``.

    Ties go to the key met first during iteration. The search starts from
    the sentinel ``("", -1)``, so it suits non-negative values such as
    counts. An empty dict, or one whose values are all <= -1, yields
    ``("", -1)``.
    """
    best = ("", -1)
    for key, value in Dict.items():
        # Strict comparison: a later equal value does not displace the leader.
        if value > best[1]:
            best = (key, value)
    return best
###_Core Functions_### | |
def Classify(nItem, k, Items):
    """Predict the class of ``nItem`` by majority vote of its k nearest neighbors.

    :param nItem: dict of feature name -> numeric value for the item to
        classify. Its keys decide which features are compared.
    :param k: number of neighbors that vote; must be at least 1.
    :param Items: labelled items, each holding the same feature keys plus a
        ``"Class"`` label.
    :returns: tuple ``(class_label, votes)`` as produced by ``FindMax`` on the
        vote counts.
    :raises ValueError: if ``k`` is less than 1.
    """
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    # Nearest neighbors found so far, as [distance, class] pairs.
    neighbors = []
    for item in Items:
        distance = EuclideanDistance(nItem, item)
        # Either add the current item to neighbors or keep the list as is.
        neighbors = UpdateNeighbors(neighbors, item, distance, k)
    # Count only the neighbors actually collected. When Items has fewer than
    # k entries, asking for k of them would index past the end of the list.
    count = CalculateNeighborsClass(neighbors, len(neighbors))
    # The class with the most votes wins.
    return FindMax(count)
def UpdateNeighbors(neighbors, item, distance, k):
    """Return the neighbor list updated with ``item`` if it is among the k nearest.

    :param neighbors: ``[distance, class]`` pairs sorted ascending (as returned
        by a previous call). The list passed in is never modified.
    :param item: labelled item; only its ``"Class"`` entry is read, and only
        when the item is actually added.
    :param distance: distance from ``item`` to the point being classified.
    :param k: maximum number of neighbors to keep.
    :returns: the updated list, sorted ascending by ``[distance, class]``.
        When the item does not qualify (including k <= 0), the input list
        itself is returned unchanged.
    """
    if len(neighbors) < k:
        # Room left: take the candidate unconditionally.
        return sorted(neighbors + [[distance, item["Class"]]])
    if k > 0 and neighbors[-1][0] > distance:
        # List is full: the candidate replaces the farthest neighbor, which
        # is the last entry because the list is kept sorted.
        return sorted(neighbors[:-1] + [[distance, item["Class"]]])
    return neighbors
###_Main_### | |
def main():
    """Classify one hard-coded sample against 'data.txt' with k = 3 and print the result."""
    items = ReadData('data.txt')
    newItem = {'PW': 1.4, 'PL': 4.7, 'SW': 3.2, 'SL': 7.0}
    # print() with one argument prints the same under Python 2 and 3;
    # the `print x` statement form is a syntax error on Python 3.
    print(Classify(newItem, 3, items))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment