Created
October 2, 2020 18:26
-
-
Save finlytics-hub/6920ee1c70dc5bbe5a9ec8053798fec4 to your computer and use it in GitHub Desktop.
ODI: Model Training & Validation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- Feature matrix ---------------------------------------------------------
# Drop the features rejected during the Feature Selection steps above.
X = final_data.drop(columns=[
    'Ground', 'Match Date', 'Result', 'Toss Won?', 'Match Month',
    'Country Total Bowling Rank', 'Country Total Batting Rank',
    'Opposition Total Bowling Rank', 'Opposition Total Batting Rank',
    'Country Average Bowling Rank', 'Country Average Batting Rank',
    'Opposition Average Bowling Rank', 'Opposition Average Batting Rank',
    'Country Median Bowling Rank', 'Country Median Batting Rank',
    'Opposition Median Bowling Rank', 'Opposition Median Batting Rank',
])
# One-hot encode the remaining categorical features; drop_first=True avoids
# the dummy-variable trap (perfect multicollinearity between the dummies).
X = pd.get_dummies(X, columns=['Country', 'Opposition', 'Home/Away'], drop_first=True)

# --- Target -----------------------------------------------------------------
# Binary target: wins -> 1, losses -> 0.
y = final_data['Result'].replace({'Won': 1, 'Lost': 0})

# --- Model & cross-validated accuracy ----------------------------------------
# class_weight='balanced' guards against any residual class imbalance;
# max_iter raised so the solver converges on this feature set.
reg = LogisticRegression(max_iter=10000, class_weight='balanced')
# 10-fold stratified CV repeated 3 times (30 fits) for a stable estimate.
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
scores = cross_val_score(reg, X, y, scoring='accuracy', cv=cv)
# Mean accuracy across all 30 folds.
Accuracy = np.mean(scores)
print(f'Mean Accuracy: {Accuracy:.4f}')

# --- ROC curve ---------------------------------------------------------------
# NOTE(review): the ROC curve below is computed on the same data the model was
# fitted on, so it is optimistic; the cross-validated accuracy above is the
# honest performance estimate.
reg.fit(X, y)
# Predicted probability of the positive class (a Win).
prob = reg.predict_proba(X)[:, 1]
fpr, tpr, thresholds = roc_curve(y, prob)
plt.plot(fpr, tpr)
# Dashed black diagonal as the no-skill-classifier reference line.
plt.plot(fpr, fpr, linestyle='--', color='k')
# Set the title and axes labels.
plt.title('ROC curve')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate');

# --- Best threshold via Youden's J statistic (J = TPR - FPR) -----------------
J = tpr - fpr
# Index of the largest J gives the threshold maximizing TPR - FPR.
ix = np.argmax(J)
best_thresh = thresholds[ix]
print(f'Best Threshold: {best_thresh:.4f}')

# --- Persist the fitted model ------------------------------------------------
# Use a context manager so the file handle is always closed (the original
# passed an un-closed open() directly to pickle.dump, leaking the handle).
# TODO: '...' is a placeholder path — replace with a real directory.
with open('.../Final_LR_Model.pkl', 'wb') as model_file:
    pickle.dump(reg, model_file)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment