Created November 11, 2016 04:06
-
-
Save benjamincohen1/bd616b40713851c4103e282a55800f5d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Created on Thu Nov 10 21:27:59 2016 | |
@author: ben | |
""" | |
import csv | |
import numpy | |
import matplotlib.pyplot as plt | |
def age_or_mean(age, default=28.5):
    """Parse a raw ``Age`` field, substituting a fill-in value when missing.

    Args:
        age: The raw value read from a CSV row, normally a string such as
            ``'22'`` or ``'0.42'``. An empty or whitespace-only string, or
            ``None``, means the age was not recorded.
        default: Value returned for a missing age. Defaults to 28.5, the
            fill-in value this script has always used (NOTE(review): presumably
            a central age of the training passengers -- confirm).

    Returns:
        The age as a float, or *default* when the field is missing.

    Raises:
        ValueError: If *age* is present but is not a valid number.
    """
    # csv.DictReader yields None for fields absent from a short row, and
    # float() rejects whitespace-only strings, so treat both as "missing".
    # Compare on str(age) rather than truthiness so a numeric 0 is kept.
    if age is None or not str(age).strip():
        return default
    return float(age)
def predict(datapoint):
    """Return the rule-based survival prediction (True/False) for one row.

    *datapoint* is a mapping with at least the keys ``'Age'``, ``'Sex'`` and
    ``'Pclass'``. The age is normalised through ``age_or_mean`` first.

    Rules:
      * age <= 15: predicted to survive only when ``Pclass`` is below 3.
      * otherwise: predicted to survive when ``Sex`` is ``'female'`` or the
        passenger is in class 1.
    """
    # Read every field up front, in the same order as before, so missing keys
    # or bad values raise exactly as they always have.
    raw_age = datapoint['Age']
    sex = datapoint['Sex']
    passenger_class = int(datapoint['Pclass'])
    passenger_age = age_or_mean(raw_age)

    if passenger_age <= 15:
        return passenger_class < 3
    return sex == 'female' or passenger_class == 1
# Score the hand-written rules in predict() against the labelled training
# data, then scatter-plot each passenger's age against their sex, coloured
# green (Survived == '1') or red (anything else).
# NOTE(review): the input path is machine-specific; adjust before running.
with open('/Users/Ben/Desktop/train.csv') as train_file:
    # Read all lines up front so the file handle is closed promptly.
    myfile = train_file.readlines()
myreader = csv.DictReader(myfile)

actuals = []
predictions = []
ages = []
classes = []
sexes = []
colors = []
for x in myreader:
    age = age_or_mean(x['Age'])
    sex = x['Sex']
    pred = predict(x)
    predictions.append(pred)
    ages.append(age)
    classes.append(int(x['Pclass']))
    # Encode sex numerically so it can serve as the scatter plot's y-axis.
    sexes.append(1 if sex == 'male' else 0)
    if x['Survived'] == '1':
        actuals.append(True)
        colors.append('g')
    else:
        actuals.append(False)
        colors.append('r')

corrects = sum(1 for actual, predicted in zip(actuals, predictions)
               if actual == predicted)
if actuals:
    # Scale the hit rate to 0-100 so the number matches the word "percent".
    accuracy = 100.0 * corrects / len(actuals)
    print("You got {:.2f} percent correct".format(accuracy))
else:
    # Avoid dividing by zero when the CSV has a header but no data rows.
    print("No rows found in the training data")

plt.scatter(ages, sexes, color=colors)
plt.xlabel('Age')
plt.ylabel('Sex (1 = male, 0 = female)')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.