Skip to content

Instantly share code, notes, and snippets.

@chyikwei
Created July 15, 2014 15:21
Show Gist options
  • Save chyikwei/b2613dcdf4ec17ef0da3 to your computer and use it in GitHub Desktop.
Save chyikwei/b2613dcdf4ec17ef0da3 to your computer and use it in GitHub Desktop.
kdd model 2
import pandas as pd
import numpy as np
from sklearn import metrics
from sklearn import cross_validation
# models
from sklearn import linear_model
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import normalize
# ---------------------------------------------------------------------------
# Input / output files
# ---------------------------------------------------------------------------
sample_submission_file_name = 'raw_data/sampleSubmission.csv'
resource_file_name = 'resource_by_projectid.csv'
# Alternative input previously used here:
#   'project_category_features_binary_na_filled.csv'
categpry_file_name = 'project_category_features_binary.csv'
# NOTE(review): directory is spelled 'downlaoded_data' -- confirm it matches
# the path on disk before "fixing" the spelling.
main_file_name = 'downlaoded_data/features_sam.csv'
submit_file_name = 'predictions_0715.csv'
# ---------------------------------------------------------------------------

# Columns dropped from the merged frames before they are converted into
# feature matrices: identifiers, the posting date, the training label
# 'main_0', and other non-feature columns.
exclude_fields = [
    'projectid',
    'teacher_acctid',
    'schoolid',
    'date_posted',
    'resource_types',
    'main_0',
    # NOTE(review): spelled 'veondorids'; presumably matches the CSV header.
    'resource_veondorids',
]

# Columns of the main feature file whose missing values are replaced by the
# column median before merging.
numerical_fields = [
    'ULOCAL',
    # ct_stud_* / ct_teach_* columns
    'ct_stud_all', 'ct_stud_azn', 'ct_stud_hsp', 'ct_stud_blk', 'ct_stud_wht',
    'ct_teach_all',
    # *_tch_* columns
    'ct_tch_exc_proj', 'ct_tch_ttl_attempt', 'ct_tch_ref', 'bn_tch_ref',
    # *_sch_* columns
    'ct_sch_non_exc_proj', 'ct_sch_exc_proj', 'ct_sch_ttl_attempt',
    'rt_sch_exc_proj', 'bn_sch_exc_proj',
    # geo columns
    # NOTE(review): both 'geo_clus_grp' and 'geo_clus_group' are listed --
    # confirm both columns really exist.
    'geo_clus_grp',
    'rnk_ct_geo_non_exc_proj', 'rnk_ct_geo_exc_proj', 'rnk_ct_geo_ttl_attempt',
    'ct_geo_non_exc_proj', 'ct_geo_exc_proj', 'rt_geo_exc_proj',
    'bn_geo_exc_proj_hta', 'ct_geo_ttl_attempt', 'geo_clus_group',
    # remaining columns
    'train',
    'ct_open_120', 'ct_open_90', 'ct_open_60', 'ct_open_30',
]
def _sort_by_projectid(df):
    """Return a copy of *df* sorted by its 'projectid' column.

    ``DataFrame.sort`` was deprecated in pandas 0.17 and removed in 0.20 in
    favour of ``DataFrame.sort_values``; use whichever this pandas provides.
    """
    if hasattr(df, 'sort_values'):
        return df.sort_values('projectid')
    return df.sort('projectid')


def build_matrix(start_date, test_start_date='2014-01-01'):
    """Load the feature files and build the train/test matrices.

    Reads ``main_file_name``, ``categpry_file_name`` and ``resource_file_name``.
    Missing values in every ``numerical_fields`` column of the main file are
    replaced by that column's median. The three tables are then inner-joined
    on ``projectid``.

    Rows posted before ``test_start_date`` and on/after ``start_date`` form
    the training set. Rows posted on/after ``test_start_date`` form the test
    set. Both sets are sorted by ``projectid``, and the ``exclude_fields``
    columns are dropped before conversion to arrays.

    Args:
        start_date: earliest ``date_posted`` (inclusive) kept for training,
            e.g. '2013-07-01'.
        test_start_date: first ``date_posted`` (inclusive) of the test
            period. Defaults to '2014-01-01', the previously hard-coded cut.

    Returns:
        (train_X, train_response, test_X): feature arrays for the training
        and test rows, and the training rows' ``main_0`` column as floats.
    """
    cat_features = pd.read_csv(categpry_file_name)
    main_file = pd.read_csv(main_file_name)
    resource_df = pd.read_csv(resource_file_name)

    # Median imputation; the median is computed over train and test rows alike.
    for field in numerical_fields:
        main_file[field] = main_file[field].fillna(main_file[field].median())

    # Inner joins: projects missing from either side table are dropped.
    all_df = pd.merge(main_file, cat_features, on='projectid')
    all_df = pd.merge(all_df, resource_df, on='projectid')

    # date_posted is read as text (no parse_dates), so these are string
    # comparisons. NOTE(review): assumes ISO 'YYYY-MM-DD' values so that
    # lexicographic order equals chronological order -- confirm.
    posted = all_df['date_posted']
    train_df = all_df[(posted < test_start_date) & (posted >= start_date)]
    test_df = all_df[posted >= test_start_date]

    # Deterministic row order; the submission step pairs test predictions with
    # a projectid-sorted sample file.
    train_df = _sort_by_projectid(train_df)
    test_df = _sort_by_projectid(test_df)

    train_response = train_df['main_0'].astype(float).values

    # drop() returns new frames, so nothing above is mutated in place.
    train_X = np.array(train_df.drop(exclude_fields, axis=1))
    test_X = np.array(test_df.drop(exclude_fields, axis=1))
    return train_X, train_response, test_X
def _holdout_auc(clf, X, y, test_size=0.2):
    """Fit *clf* on a random split of (X, y); return (train_auc, holdout_auc).

    The split comes from ``cross_validation.train_test_split`` without a
    random_state, so it differs between runs. Scores are ROC AUC of
    ``predict_proba(...)[:, 1]``.
    """
    X_train, X_test, y_train, y_test = cross_validation.train_test_split(
        X, y, test_size=test_size)
    clf.fit(X_train, y_train)
    train_auc = metrics.roc_auc_score(y_train, clf.predict_proba(X_train)[:, 1])
    test_auc = metrics.roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
    return train_auc, test_auc


def main():
    """Train a classifier, report hold-out AUC, and write a submission CSV.

    Builds the matrices with ``build_matrix(start_date='2013-07-01')``. It
    prints train/hold-out AUC from an 80/20 split, refits on all training
    rows, and writes the test-row probabilities into the sample submission's
    ``is_exciting`` column, saved as ``submit_file_name``.

    Raises:
        ValueError: if the sample submission and the test matrix have a
            different number of rows.
    """
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print("load file....")
    full_train_X, full_train_y, test_X = build_matrix(start_date='2013-07-01')

    # Pick a model and tune its parameters here.
    # lr = linear_model.LogisticRegression(class_weight={1: 1, 0: 1}, C=0.1)
    # rf = RandomForestClassifier(n_estimators=200, max_depth=8, min_samples_split=15)
    clf = GradientBoostingClassifier(n_estimators=100, max_depth=4, min_samples_split=5)

    train_auc, test_auc = _holdout_auc(clf, full_train_X, full_train_y)
    print('AUC train:%.4f, test:%.4f' % (train_auc, test_auc))

    # Refit on every training row before predicting the test period.
    clf.fit(full_train_X, full_train_y)
    test_preds = clf.predict_proba(test_X)[:, 1]

    sample = pd.read_csv(sample_submission_file_name)
    if hasattr(sample, 'sort_values'):
        sample = sample.sort_values('projectid')
    else:  # pandas < 0.17 only has DataFrame.sort
        sample = sample.sort('projectid')

    # Predictions are assigned by position. This relies on build_matrix
    # returning the test rows sorted by projectid, covering exactly the
    # sample's projects. A row-count mismatch is at least detectable.
    if len(sample) != len(test_preds):
        raise ValueError(
            "sample submission has %d rows but %d test predictions were made"
            % (len(sample), len(test_preds)))
    sample['is_exciting'] = test_preds
    sample.to_csv(submit_file_name, index=False)
    print("submission file generated: %s" % submit_file_name)
if __name__ == '__main__':
    # Run the pipeline only when executed as a script, never on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment