Skip to content

Instantly share code, notes, and snippets.

@finlytics-hub
Created October 2, 2020 18:26
Show Gist options
  • Save finlytics-hub/6920ee1c70dc5bbe5a9ec8053798fec4 to your computer and use it in GitHub Desktop.
ODI: Model Training & Validation
# Feature Selection outcome: these columns did not make the cut, so remove them
dropped_features = [
    'Ground', 'Match Date', 'Result', 'Toss Won?', 'Match Month',
    'Country Total Bowling Rank', 'Country Total Batting Rank',
    'Opposition Total Bowling Rank', 'Opposition Total Batting Rank',
    'Country Average Bowling Rank', 'Country Average Batting Rank',
    'Opposition Average Bowling Rank', 'Opposition Average Batting Rank',
    'Country Median Bowling Rank', 'Country Median Batting Rank',
    'Opposition Median Bowling Rank', 'Opposition Median Batting Rank',
]
X = final_data.drop(columns=dropped_features)
# one-hot encode the remaining categorical features; drop the first level of
# each to avoid the dummy-variable trap
X = pd.get_dummies(X, columns=['Country', 'Opposition', 'Home/Away'], drop_first=True)
# binary target: 1 for a win, 0 for a loss
y = final_data['Result'].replace({'Won': 1, 'Lost': 0})
# 'balanced' class_weight guards against any residual class imbalance; the
# generous max_iter gives the solver room to converge on the dummy-encoded data
reg = LogisticRegression(max_iter=10000, class_weight='balanced')
# evaluation protocol: 10-fold stratified CV, repeated 3 times with a fixed
# seed for a reproducible, lower-variance estimate
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# cross-validated accuracy per fold, then the mean across all 30 folds
scores = cross_val_score(reg, X, y, scoring='accuracy', cv=cv)
Accuracy = np.mean(scores)
print(f'Mean Accuracy: {Accuracy:.4f}')
# --- ROC curve ---
# NOTE(review): the model is fitted and evaluated on the same data, so this is
# an in-sample (optimistic) ROC — confirm that is intentional
reg.fit(X, y)
# predicted probability of the positive class (a win)
prob = reg.predict_proba(X)[:, 1]
# false-positive rate, true-positive rate, and the matching decision thresholds
fpr, tpr, thresholds = roc_curve(y, prob)
# plot the ROC curve
plt.plot(fpr, tpr)
# dashed black diagonal = no-skill classifier; use explicit [0, 1] endpoints so
# the chance line is drawn correctly regardless of the fpr values (the original
# plotted fpr against itself, which only works by accident of the data)
plt.plot([0, 1], [0, 1], linestyle='--', color='k')
# set the title and axes labels
plt.title('ROC curve')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate');
# Youden's J statistic (sensitivity + specificity - 1) for every threshold
youden_j = tpr - fpr
# the threshold that maximises J is the optimal operating point on the curve
best_idx = np.argmax(youden_j)
best_thresh = thresholds[best_idx]
print(f'Best Threshold: {best_thresh:.4f}')
# persist the fitted model to disk; the context manager guarantees the file
# handle is closed even if pickling raises (the original leaked the handle by
# passing an un-closed open() directly to pickle.dump)
# NOTE(review): '.../Final_LR_Model.pkl' looks like a placeholder path — point
# it at a real directory before running
with open('.../Final_LR_Model.pkl', 'wb') as model_file:
    pickle.dump(reg, model_file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment