from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Separate the target column from the feature columns
y = data_df["target"].values
x = data_df.drop(["target"], axis=1)

# Scale the features: KNN is distance-based, so a feature with a large range would dominate otherwise
ss = StandardScaler()
x = ss.fit_transform(x)

# Split into train and test sets: 70% training, 30% test
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.3)
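As a small illustration of what the scaling step computes, here is a sketch in plain Python. It mirrors the z-score that `StandardScaler` applies (mean removal and division by the population standard deviation), and it learns the statistics from the training values only, which is a common way to keep test-set information out of preprocessing. The helper names `fit_standardizer` and `standardize` and the toy columns are hypothetical, not part of the original code.

```python
from statistics import fmean, pstdev


def fit_standardizer(column):
    """Learn the mean and population std (ddof=0) of one training feature column."""
    mean = fmean(column)
    std = pstdev(column)
    if std == 0:
        std = 1.0  # constant feature: avoid division by zero
    return mean, std


def standardize(column, mean, std):
    """Apply z = (value - mean) / std with statistics learned elsewhere."""
    return [(v - mean) / std for v in column]


# Toy values for one feature (illustrative only)
train_col = [1.0, 2.0, 3.0, 4.0]
test_col = [5.0, 0.0]

mean, std = fit_standardizer(train_col)            # statistics from training rows only
train_scaled = standardize(train_col, mean, std)
test_scaled = standardize(test_col, mean, std)     # test rows reuse the training statistics

print(train_scaled)
print(test_scaled)
```

With scikit-learn the same idea is usually expressed by calling `fit_transform` on the training split and `transform` on the test split.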
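from sklearn.neighbors import KNeighborsClassifier

train_score = []
test_score = []
k_vals = []
# Fit one model per k from 1 to 20 and record its accuracy
for k in range(1, 21):
    k_vals.append(k)
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    tr_score = knn.score(X_train, y_train)
    # Assumed completion: the snippet is cut off after the training score, but
    # test_score is read later, so both scores are recorded here
    te_score = knn.score(X_test, y_test)
    train_score.append(tr_score)
    test_score.append(te_score)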
# Best accuracy on the test set, and every k that reaches it
max_test_score = max(test_score)
test_scores_ind = [i for i, v in enumerate(test_score) if v == max_test_score]
# k starts at 1, so list index i corresponds to k = i + 1
best_k = [i + 1 for i in test_scores_ind]
print('Max test score {} and k = {}'.format(max_test_score * 100, best_k))
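The tie handling above can be seen on a toy example. This sketch uses made-up accuracies (not results from the data set) and pairs each score with its k directly, which avoids the index-plus-one conversion. Picking the smallest tied k at the end is just one possible tie-break, shown as an assumption.

```python
# Hypothetical accuracies for k = 1..6 (illustrative values only)
k_vals = [1, 2, 3, 4, 5, 6]
test_score = [0.80, 0.85, 0.90, 0.88, 0.90, 0.86]

best = max(test_score)

# Keep every k whose score equals the maximum, so ties are not silently dropped
best_ks = [k for k, s in zip(k_vals, test_score) if s == best]

# One possible tie-break: take the smallest tied k
chosen_k = min(best_ks)

print(best, best_ks, chosen_k)
```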
# Set up a KNN classifier with k = 8 neighbors
knn = KNeighborsClassifier(n_neighbors=8)
knn.fit(X_train, y_train)
# Mean accuracy on the held-out test set
knn.score(X_test, y_test)
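import pandas as pd
from sklearn.metrics import confusion_matrix
# Confusion matrix: rows are actual classes, columns are predicted classes
y_pred = knn.predict(X_test)
confusion_matrix(y_test, y_pred)
pd.crosstab(y_test, y_pred, rownames=['Actual'], colnames=['Predicted'], margins=True)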
from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred))
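To make the link between the confusion matrix and the report concrete, the following sketch derives precision, recall, and accuracy from the four cell counts of a binary problem. The labels are toy values (not model output), and class 1 is assumed to be the positive class.

```python
from collections import Counter

# Toy labels (illustrative only); 1 is treated as the positive class
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_hat = [1, 0, 0, 1, 0, 1, 1, 0]

# Count each (actual, predicted) pair; this is the confusion matrix in dict form
counts = Counter(zip(y_true, y_hat))
tn, fp = counts[(0, 0)], counts[(0, 1)]
fn, tp = counts[(1, 0)], counts[(1, 1)]

precision = tp / (tp + fp)          # of the predicted positives, how many are right
recall = tp / (tp + fn)             # of the actual positives, how many are found
accuracy = (tp + tn) / len(y_true)  # share of all predictions that are right

print([[tn, fp], [fn, tp]])
print(precision, recall, accuracy)
```

The nested list follows the same layout as `confusion_matrix` for labels 0 and 1: actual classes in rows, predicted classes in columns.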
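from sklearn.metrics import roc_curve
# Predicted probability of the positive class (second column), then the ROC points
y_pred_proba = knn.predict_proba(X_test)[:, 1]
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)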
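import matplotlib.pyplot as plt

plt.figure(figsize=(10, 8))
plt.plot([0, 1], [0, 1], 'k--')  # diagonal: a classifier with no skill
plt.plot(fpr, tpr, label='Knn')
plt.xlabel('FPR')
plt.ylabel('TPR')
# Read k from the fitted model so the title always matches it
plt.title('Knn(n_neighbors={}) ROC curve'.format(knn.n_neighbors))
plt.legend()
plt.show()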
from sklearn.metrics import roc_auc_score
roc_auc_score(y_test, y_pred_proba)
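One way to read the AUC value is as the probability that a randomly chosen positive example gets a higher score than a randomly chosen negative one, with ties counted as half. The sketch below computes it that way on toy scores. The function name `auc_by_pairs` and the data are hypothetical, and the quadratic pair loop is only meant for small illustrations.

```python
def auc_by_pairs(y_true, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy labels and scores (illustrative only)
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]

print(auc_by_pairs(y_true, scores))  # 3 of the 4 pairs are ranked correctly
```

Ties matter for KNN in particular: with uniform weights the predicted probability is the fraction of the k neighbors in the positive class, so it can only take k + 1 distinct values.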
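from sklearn.metrics import precision_recall_curve

precision, recall, thresholds = precision_recall_curve(y_test, y_pred_proba)
# No-skill baseline: precision equals the share of positives in the test set
# (assumes the positive class is labeled 1; it is 0.5 only for balanced classes)
baseline = (y_test == 1).mean()
plt.figure(figsize=(10, 8))
plt.plot([0, 1], [baseline, baseline], 'k--')
plt.plot(recall, precision, label='Knn')
plt.xlabel('recall')
plt.ylabel('precision')
# Read k from the fitted model so the title always matches it
plt.title('Knn(n_neighbors={}) PRC curve'.format(knn.n_neighbors))
plt.legend()
plt.show()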
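The dashed reference line in a precision-recall plot represents a classifier with no skill. A quick way to see where it belongs: a model that labels every example positive reaches recall 1 with precision equal to the share of positives, so a line at 0.5 is the right reference only when the classes are balanced. The sketch below checks this on toy labels (illustrative values, with 1 assumed to be the positive class).

```python
# Toy labels with 30% positives (illustrative only)
y_true = [1, 0, 0, 1, 0, 0, 0, 1, 0, 0]
prevalence = sum(y_true) / len(y_true)

# A no-skill classifier that predicts the positive class for every example
y_all_pos = [1] * len(y_true)

tp = sum(1 for t, p in zip(y_true, y_all_pos) if t == 1 and p == 1)
fp = sum(1 for t, p in zip(y_true, y_all_pos) if t == 0 and p == 1)

precision = tp / (tp + fp)
recall = tp / sum(y_true)

print(prevalence, precision, recall)
```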