# Created March 31, 2013 06:17
# Source: GitHub gist sl8r000/5279746
# NOTE: GitHub warned that this file contains bidirectional Unicode text, which
# may be interpreted or compiled differently than it appears; review it in an
# editor that reveals hidden Unicode characters.
import os
import sys

# Make the directory two levels above this script importable so the ``knn``
# package imported below can be found.  It is inserted at index 1 rather than
# 0 so that sys.path[0] keeps priority.  The temporary name is deleted so it
# does not linger as a module attribute.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '../..'))
if _project_root not in sys.path:
    sys.path.insert(1, _project_root)
del _project_root

import knn
from knn.nearest_neighbor1 import NearestNeighbor

# ROW_PATH is a relative path, so show which working directory it resolves from.
print('%s %s' % ('Current path', os.getcwd()))
# Whitespace-separated data file read by get_all_rows(); being relative, it is
# resolved against the current working directory.
ROW_PATH = 'make_it_so/1100_tweets.txt'

# Column-index groups consulted by get_all_rows() to decide how each token of
# a line is converted.  Together they cover columns 0-16 with no overlap.
INT_INDICES = [0, 1] + list(range(3, 9)) + [10, 12, 13, 15, 16]
FLOAT_INDICES = [2]
BOOL_INDICES = [9]
STRING_INDICES = [11, 14]

# Column-14 values that main() labels 'TOS' and 'TNG' respectively
# (presumably Twitter handles of the two casts -- confirm against the data).
TOS_CAST = [
    'GineokwKoenig',
    'RealNichelle',
    'TheRealNimoy',
    'WilliamShatner',
    'GeorgeTakei',
]
TNG_CAST = [
    'levarburton',
    'jonathansfrakes',
    'BrentSpiner',
    'SirPatStew',
    'wilw',
]
def get_all_rows():
    """Read ROW_PATH and return every line as a list of typed column values.

    Each line is split on whitespace.  The token in column ``i`` is converted
    according to which group contains ``i``: INT_INDICES -> ``int``,
    FLOAT_INDICES -> ``float``, BOOL_INDICES -> ``bool`` (parsed from text),
    STRING_INDICES -> left as ``str``.

    Returns:
        A list with one entry per line of the file; each entry is the list of
        converted values for that line (an empty list for a blank line).

    Raises:
        IOError/OSError: if ROW_PATH cannot be opened.
        ValueError: if an int, float or bool column holds an unparseable token.
        Exception: if a line has a column index that is in none of the groups.
    """
    # ``with`` guarantees the file is closed whether parsing succeeds or not.
    with open(ROW_PATH) as row_file:
        return [_parse_row(line.split()) for line in row_file]


# Lower-cased spellings accepted for boolean columns.
_TRUE_STRINGS = ('true', 't', 'yes', 'y', '1')
_FALSE_STRINGS = ('false', 'f', 'no', 'n', '0')


def _parse_bool(token):
    """Convert a textual boolean such as 'True'/'False' or '1'/'0' to bool."""
    # bool(token) would return True for every non-empty string, 'False' included.
    lowered = token.strip().lower()
    if lowered in _TRUE_STRINGS:
        return True
    if lowered in _FALSE_STRINGS:
        return False
    raise ValueError('Cannot interpret {!r} as a boolean'.format(token))


def _parse_row(tokens):
    """Convert one line's whitespace-split tokens into typed column values."""
    row = []
    for index, value in enumerate(tokens):
        if index in INT_INDICES:
            value = int(value)
        elif index in FLOAT_INDICES:
            value = float(value)
        elif index in BOOL_INDICES:
            value = _parse_bool(value)
        elif index not in STRING_INDICES:
            raise Exception(
                'Unexpected object type for {} (column {})'.format(value, index))
        row.append(value)
    return row
def _cast_of(handle):
    """Return 'TOS' or 'TNG' for a handle listed in TOS_CAST / TNG_CAST, else None."""
    if handle in TOS_CAST:
        return 'TOS'
    if handle in TNG_CAST:
        return 'TNG'
    return None


def main(train_size=1000, target_column=14):
    """Train a NearestNeighbor model and count mis-attributed casts.

    The first ``train_size`` rows returned by get_all_rows() are passed to
    ``model.learn_from_row``.  Every later non-empty row is held out: its
    ``target_column`` value is predicted with ``model.predict_missing_column``.
    The actual and predicted values are then mapped to 'TOS'/'TNG' (or None
    when in neither cast list).  A row counts as a mistake when the two labels
    differ.  Per-row details and the final mistake count go to stdout.

    Args:
        train_size: number of leading rows used for training.
        target_column: index of the column to predict.
    """
    model = NearestNeighbor()
    all_rows = get_all_rows()
    for row in all_rows[:train_size]:
        model.learn_from_row(row)

    mistakes = 0
    # Evaluate on every row after the training rows, starting immediately at
    # index train_size so no row is left unused.  Blank lines parse to empty
    # rows and are skipped instead of raising IndexError.
    for row in all_rows[train_size:]:
        if not row:
            continue
        actual_value = row[target_column]
        predicted_value = model.predict_missing_column(row, target_column)
        print('{} {}'.format(actual_value, predicted_value))
        # Both labels are recomputed for every row, so a handle found in
        # neither list yields None rather than an unbound or stale label.
        actual_cast = _cast_of(actual_value)
        predicted_cast = _cast_of(predicted_value)
        print('{} {}'.format(actual_cast, predicted_cast))
        if actual_cast != predicted_cast:
            print('\twrong')
            mistakes += 1
        print('')
    print('Total mistakes {}'.format(mistakes))


if __name__ == '__main__':
    main()
# (End of scraped GitHub gist page; sign in to GitHub to comment on the gist.)