Skip to content

Instantly share code, notes, and snippets.

@thuliumsystems
Created May 11, 2023 02:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thuliumsystems/db562b9fdb2efbbd55d3ae765059704d to your computer and use it in GitHub Desktop.
Save thuliumsystems/db562b9fdb2efbbd55d3ae765059704d to your computer and use it in GitHub Desktop.
GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import (
RandomForestClassifier,
AdaBoostClassifier,
GradientBoostingClassifier,
)
from sklearn.preprocessing import StandardScaler
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.model_selection import GridSearchCV
# Labelled dataset used for every grid search in this script. Each row holds
# 10 integer feature values followed by the class label (0 or 1) in the last
# position.
# NOTE(review): the meaning/units of the 10 feature columns are not documented
# here — TODO document them.
# NOTE(review): many rows are exact duplicates (e.g.
# [10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1] appears four times, and
# [1, 0, 0, 0, 0, 31, 11, 6, 5, 7, 0] twice). With 5-fold cross-validation a
# duplicated row can land in both a training fold and its validation fold,
# which likely inflates the reported scores — consider de-duplicating.
training = [
[10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
[21, 16, 12, 26, 6, 0, 0, 0, 0, 0, 1],
[2, 5, 2, 3, 3, 3, 2, 3, 2, 4, 1],
[13, 6, 3, 4, 4, 1, 1, 0, 0, 1, 1],
[2, 3, 4, 2, 5, 2, 1, 3, 2, 1, 1],
[2, 3, 2, 3, 4, 2, 4, 3, 2, 2, 1],
[1, 0, 1, 1, 1, 4, 33, 2, 9, 4, 0],
[6, 4, 3, 4, 2, 6, 2, 4, 1, 4, 1],
[29, 12, 3, 6, 6, 2, 1, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 10, 27, 8, 8, 13, 0],
[21, 5, 11, 12, 9, 0, 0, 0, 0, 0, 1],
[5, 2, 2, 2, 2, 4, 5, 3, 4, 2, 1],
[14, 20, 6, 8, 8, 0, 0, 0, 0, 0, 1],
[2, 3, 2, 1, 1, 3, 7, 9, 3, 3, 0],
[5, 2, 5, 2, 4, 2, 3, 3, 2, 1, 1],
[8, 6, 4, 3, 2, 3, 4, 2, 1, 3, 1],
[5, 2, 0, 0, 1, 7, 5, 9, 5, 2, 0],
[6, 19, 2, 3, 4, 2, 0, 1, 1, 0, 1],
[0, 0, 0, 0, 0, 6, 16, 6, 19, 22, 0],
[17, 5, 28, 23, 4, 0, 0, 0, 0, 0, 1],
[4, 15, 2, 7, 11, 0, 0, 0, 0, 0, 1],
[3, 5, 4, 3, 8, 1, 2, 1, 2, 3, 1],
[1, 1, 0, 1, 0, 7, 7, 5, 28, 6, 0],
[5, 5, 3, 3, 8, 1, 1, 2, 1, 1, 1],
[2, 1, 1, 1, 1, 9, 7, 8, 2, 6, 0],
[7, 5, 5, 2, 4, 2, 2, 1, 1, 1, 1],
[3, 23, 3, 5, 6, 2, 1, 2, 2, 1, 1],
[1, 1, 1, 0, 1, 8, 5, 3, 2, 6, 0],
[11, 9, 6, 1, 3, 0, 2, 1, 1, 2, 1],
[11, 3, 8, 2, 5, 1, 1, 2, 1, 1, 1],
[8, 11, 8, 5, 10, 1, 0, 0, 1, 0, 1],
[16, 19, 6, 8, 2, 0, 0, 0, 0, 1, 1],
[7, 3, 5, 3, 5, 5, 4, 1, 1, 2, 1],
[1, 1, 1, 0, 1, 24, 5, 5, 5, 4, 0],
[1, 0, 1, 0, 1, 12, 8, 10, 7, 4, 0],
[6, 16, 5, 16, 7, 1, 0, 1, 0, 0, 1],
[4, 4, 2, 2, 3, 5, 7, 1, 1, 2, 0],
[4, 8, 3, 3, 5, 2, 1, 1, 1, 3, 1],
[9, 10, 5, 7, 5, 1, 1, 1, 0, 0, 1],
[4, 1, 1, 1, 2, 22, 4, 19, 4, 2, 0],
[0, 0, 0, 0, 1, 13, 9, 10, 4, 11, 0],
[14, 14, 9, 4, 4, 1, 1, 0, 0, 0, 1],
[21, 24, 4, 8, 5, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 13, 6, 6, 34, 4, 0],
[7, 5, 3, 3, 1, 9, 2, 1, 3, 4, 1],
[0, 0, 0, 0, 0, 6, 10, 4, 13, 30, 0],
[5, 8, 10, 9, 3, 1, 0, 1, 0, 1, 1],
[7, 6, 5, 4, 2, 1, 3, 2, 1, 1, 1],
[1, 1, 1, 1, 1, 5, 5, 7, 5, 5, 0],
[1, 1, 2, 1, 2, 7, 4, 4, 3, 2, 0],
[1, 0, 0, 0, 0, 31, 11, 6, 5, 7, 0],
[1, 1, 3, 1, 1, 2, 4, 10, 4, 3, 0],
[1, 2, 1, 1, 3, 5, 8, 9, 3, 7, 0],
[7, 4, 4, 7, 3, 1, 2, 3, 2, 1, 1],
[0, 1, 0, 1, 0, 12, 7, 22, 4, 16, 0],
[1, 4, 3, 1, 1, 4, 5, 11, 2, 7, 0],
[3, 2, 2, 3, 9, 7, 8, 1, 1, 2, 0],
[3, 3, 1, 2, 2, 4, 7, 4, 3, 5, 0],
[6, 7, 29, 3, 6, 0, 1, 0, 1, 1, 1],
[0, 0, 1, 0, 0, 3, 10, 12, 10, 10, 0],
[1, 1, 1, 2, 1, 8, 4, 5, 5, 4, 0],
[0, 1, 0, 1, 0, 4, 22, 7, 4, 3, 0],
[5, 3, 4, 5, 3, 3, 1, 2, 2, 1, 1],
[1, 0, 1, 1, 0, 6, 8, 13, 6, 4, 0],
[2, 1, 1, 0, 0, 18, 6, 38, 5, 8, 0],
[1, 0, 3, 2, 2, 5, 8, 4, 3, 4, 0],
[5, 18, 6, 5, 5, 0, 1, 0, 0, 1, 1],
[0, 1, 2, 1, 3, 4, 6, 12, 2, 13, 0],
[0, 1, 1, 1, 0, 7, 15, 6, 4, 7, 0],
[0, 0, 1, 1, 1, 5, 18, 7, 3, 6, 0],
[6, 7, 5, 9, 4, 1, 1, 0, 1, 1, 1],
[0, 3, 0, 0, 0, 8, 6, 7, 19, 2, 0],
[0, 1, 0, 0, 0, 7, 13, 11, 13, 6, 0],
[0, 0, 1, 0, 0, 7, 5, 19, 8, 10, 0],
[0, 1, 2, 0, 0, 19, 28, 10, 7, 5, 0],
[17, 10, 10, 9, 5, 1, 0, 0, 0, 0, 1],
[20, 9, 7, 8, 14, 0, 0, 0, 0, 0, 1],
[0, 2, 0, 1, 0, 14, 4, 40, 4, 2, 0],
[0, 0, 0, 0, 0, 5, 10, 6, 11, 11, 0],
[2, 3, 3, 5, 3, 4, 5, 3, 2, 1, 0],
[17, 3, 3, 4, 12, 2, 1, 0, 1, 1, 1],
[7, 3, 6, 5, 4, 5, 2, 1, 1, 0, 1],
[3, 3, 1, 1, 1, 4, 1, 6, 2, 3, 0],
[4, 5, 3, 2, 4, 1, 4, 1, 1, 0, 1],
[5, 7, 5, 2, 10, 1, 1, 1, 1, 2, 1],
[1, 1, 0, 0, 1, 7, 5, 13, 10, 6, 0],
[1, 0, 0, 1, 0, 21, 8, 6, 3, 9, 0],
[8, 3, 1, 3, 5, 2, 3, 0, 2, 2, 1],
[4, 1, 0, 1, 2, 20, 6, 2, 5, 3, 0],
[1, 1, 1, 0, 0, 4, 11, 6, 4, 9, 0],
[9, 3, 8, 2, 2, 2, 3, 0, 1, 1, 1],
[7, 5, 6, 4, 2, 4, 3, 1, 1, 1, 1],
[10, 6, 4, 6, 7, 0, 2, 1, 0, 1, 1],
[4, 6, 5, 8, 8, 1, 0, 1, 1, 0, 1],
[3, 4, 4, 3, 7, 2, 3, 1, 1, 2, 1],
[6, 9, 10, 4, 8, 1, 1, 1, 2, 1, 1],
[5, 9, 4, 3, 3, 1, 1, 0, 0, 1, 1],
[9, 1, 2, 2, 2, 5, 5, 3, 3, 4, 0],
[0, 0, 2, 1, 0, 9, 6, 6, 9, 7, 0],
[14, 11, 7, 4, 2, 2, 1, 1, 1, 1, 1],
[10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
[7, 6, 7, 4, 22, 0, 1, 0, 0, 1, 1],
[10, 2, 8, 2, 2, 2, 3, 1, 4, 2, 1],
[11, 32, 15, 8, 5, 0, 0, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 3, 14, 7, 6, 5, 0],
[6, 19, 8, 19, 16, 0, 0, 0, 0, 0, 1],
[2, 3, 1, 0, 1, 13, 6, 5, 3, 2, 0],
[17, 17, 36, 3, 3, 0, 0, 0, 1, 1, 1],
[1, 3, 1, 2, 1, 16, 4, 4, 1, 5, 0],
[9, 3, 21, 7, 3, 1, 0, 1, 1, 0, 1],
[17, 18, 3, 19, 5, 1, 0, 0, 0, 0, 1],
[1, 3, 1, 1, 2, 11, 3, 4, 3, 4, 0],
[6, 3, 6, 3, 2, 1, 2, 3, 2, 3, 1],
[2, 2, 2, 2, 2, 4, 3, 5, 3, 6, 0],
[6, 14, 11, 4, 17, 0, 0, 1, 0, 0, 1],
[10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
[25, 5, 3, 3, 3, 1, 1, 0, 1, 2, 1],
[6, 31, 6, 5, 8, 0, 1, 0, 0, 0, 1],
[4, 3, 3, 2, 1, 3, 4, 2, 2, 4, 1],
[1, 1, 1, 1, 0, 5, 5, 13, 13, 10, 0],
[18, 9, 18, 5, 19, 0, 0, 0, 0, 0, 1],
[0, 1, 2, 0, 0, 4, 14, 5, 3, 6, 0],
[5, 6, 2, 3, 2, 2, 2, 2, 2, 1, 1],
[23, 5, 2, 3, 4, 4, 2, 1, 2, 1, 1],
[7, 5, 8, 8, 7, 2, 1, 1, 0, 0, 1],
[1, 3, 1, 2, 2, 10, 7, 4, 3, 4, 0],
[3, 3, 2, 2, 2, 2, 6, 4, 2, 3, 1],
[10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
[15, 6, 11, 7, 5, 0, 1, 0, 1, 0, 1],
[11, 3, 8, 2, 5, 1, 1, 2, 1, 1, 1],
[11, 5, 11, 3, 3, 1, 2, 1, 1, 1, 1],
[8, 29, 6, 9, 7, 1, 0, 0, 0, 1, 1],
[6, 16, 5, 16, 7, 1, 0, 1, 0, 0, 1],
[3, 2, 8, 2, 1, 3, 3, 3, 1, 1, 1],
[21, 24, 4, 8, 5, 0, 0, 0, 0, 0, 1],
[0, 1, 1, 1, 1, 6, 8, 5, 17, 4, 0],
[19, 16, 12, 4, 4, 0, 0, 1, 1, 0, 1],
[3, 1, 2, 1, 1, 10, 7, 8, 3, 10, 0],
[1, 0, 0, 0, 0, 31, 11, 6, 5, 7, 0],
[16, 7, 18, 8, 2, 0, 1, 1, 0, 0, 1],
[0, 1, 0, 0, 0, 7, 13, 11, 13, 6, 0],
[0, 0, 1, 0, 0, 7, 5, 19, 8, 10, 0],
[17, 10, 10, 9, 5, 1, 0, 0, 0, 0, 1],
[2, 1, 0, 1, 1, 14, 10, 6, 5, 5, 0],
[12, 7, 4, 5, 6, 0, 1, 1, 0, 0, 1],
[1, 1, 1, 0, 0, 8, 13, 9, 9, 6, 0],
[1, 2, 1, 1, 1, 18, 7, 9, 3, 1, 0],
[7, 2, 3, 4, 2, 3, 5, 3, 3, 2, 1],
[2, 3, 3, 5, 3, 4, 5, 3, 2, 1, 0],
[10, 5, 7, 3, 3, 1, 2, 1, 1, 1, 1],
[8, 9, 5, 5, 2, 2, 2, 3, 0, 0, 1],
[4, 3, 10, 3, 1, 1, 3, 3, 4, 1, 1],
[7, 3, 8, 3, 8, 1, 1, 1, 2, 1, 1],
[1, 3, 4, 1, 3, 2, 2, 5, 2, 1, 1],
[1, 1, 0, 0, 1, 7, 5, 13, 10, 6, 0],
[6, 5, 4, 2, 2, 1, 4, 2, 1, 1, 1],
[1, 3, 1, 2, 1, 16, 4, 4, 1, 5, 0],
[4, 6, 4, 7, 3, 2, 1, 2, 2, 1, 1],
[4, 4, 16, 5, 16, 3, 0, 1, 0, 1, 1],
[5, 10, 14, 16, 3, 0, 0, 0, 0, 0, 1],
[23, 5, 2, 3, 4, 4, 2, 1, 2, 1, 1],
[2, 1, 2, 1, 3, 2, 5, 5, 3, 3, 0],
[5, 5, 8, 4, 2, 3, 3, 1, 1, 2, 1],
[3, 5, 3, 2, 4, 2, 3, 3, 0, 2, 1],
[6, 19, 4, 4, 8, 0, 1, 0, 0, 0, 1],
[1, 2, 0, 1, 1, 19, 3, 6, 9, 3, 0],
[2, 2, 2, 0, 1, 4, 10, 9, 3, 5, 0],
[2, 0, 0, 0, 0, 13, 6, 23, 9, 11, 0],
[5, 34, 14, 14, 4, 0, 0, 1, 1, 0, 1],
[19, 2, 5, 2, 21, 0, 6, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 21, 18, 20, 6, 9, 0],
[13, 16, 22, 5, 13, 0, 0, 0, 0, 0, 1],
[6, 1, 1, 1, 0, 5, 12, 6, 5, 2, 0],
[5, 12, 4, 10, 4, 2, 1, 1, 1, 1, 1],
[17, 18, 6, 8, 10, 1, 0, 0, 0, 0, 1],
[9, 15, 10, 2, 4, 2, 1, 1, 2, 3, 1],
[1, 0, 0, 0, 0, 16, 19, 10, 4, 6, 0],
[6, 5, 3, 1, 3, 3, 2, 2, 1, 1, 1],
[6, 4, 5, 8, 8, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 1, 4, 5, 22, 3, 4, 0],
[2, 2, 1, 0, 1, 7, 11, 6, 2, 3, 0],
[6, 10, 3, 2, 5, 1, 4, 0, 1, 0, 1],
[2, 2, 5, 2, 1, 4, 4, 3, 4, 2, 1],
[2, 1, 0, 1, 1, 5, 12, 7, 7, 4, 0],
[8, 22, 4, 4, 3, 1, 1, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 26, 15, 15, 7, 9, 0],
[10, 5, 3, 3, 3, 2, 2, 3, 1, 1, 1],
[0, 0, 0, 0, 0, 21, 9, 21, 13, 2, 0],
[13, 4, 7, 7, 5, 1, 1, 1, 1, 0, 1],
[12, 25, 6, 2, 6, 5, 1, 0, 0, 1, 1],
[17, 4, 1, 1, 1, 4, 4, 8, 1, 4, 0],
[0, 0, 0, 0, 0, 10, 22, 19, 5, 9, 0],
[2, 5, 4, 2, 2, 9, 2, 2, 1, 1, 0],
]
# Split every row into its feature columns (everything except the last value)
# and its trailing 0/1 class label.
X = [row[:-1] for row in training]
y = [row[-1] for row in training]

# Standardize the feature columns only: fitting on the full rows would also
# standardize the label column as if it were a feature.
# NOTE: scaled_data is not what the grid searches are fitted on. Fitting a
# scaler on every row before cross-validation would leak validation-fold
# statistics into training; per-fold scaling belongs inside a Pipeline.
scaler = StandardScaler()
scaled_data = scaler.fit_transform(X)
# Hyper-parameter grids handed to GridSearchCV, one per estimator. Keys are the
# estimator's own constructor parameter names.
param_rf = {
    "n_estimators": [50, 100, 200],
    "max_depth": [None, 5, 10],
    "min_samples_split": [2, 5, 10],
}
param_dt = {
    "max_depth": [None, 5, 10],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 4],
}
# The default "lbfgs" solver only supports the "l2" penalty, so every "l1"
# candidate would fail to fit; "liblinear" supports both penalties searched.
param_lr = {
    "penalty": ["l1", "l2"],
    "C": [0.001, 0.01, 0.1, 1, 10],
    "solver": ["liblinear"],
}
# NOTE(review): a learning_rate of 10 is far outside the usual range for
# boosting; kept so the search space is otherwise unchanged.
param_ab = {"n_estimators": [50, 100, 200], "learning_rate": [0.001, 0.01, 0.1, 1, 10]}
param_gbc = {"n_estimators": [50, 100, 200], "learning_rate": [0.001, 0.01, 0.1, 1, 10]}
param_gnb = {"var_smoothing": [1e-9, 1e-8, 1e-7, 1e-6, 1e-5]}
param_knn = {"n_neighbors": [3, 5, 7, 9, 11]}
param_mlp = {
    "hidden_layer_sizes": [(10,), (50,), (100,)],
    "activation": ["tanh", "relu"],
    "alpha": [0.0001, 0.001, 0.01],
}
# gamma is ignored by the linear kernel, so it is only searched for "rbf";
# crossing it with "linear" would refit identical linear models 4x per C.
param_svc = [
    {"kernel": ["linear"], "C": [0.1, 1, 10, 100]},
    {
        "kernel": ["rbf"],
        "C": [0.1, 1, 10, 100],
        "gamma": [0.001, 0.01, 0.1, 1],
    },
]
param_qda = {"reg_param": [0.1, 0.5, 1.0]}
def fn_grid_search(model, params, X, y, cv=5, *, scale=False):
    """Exhaustively grid-search *model* and print the best CV result.

    Parameters
    ----------
    model : estimator
        Unfitted scikit-learn classifier to tune.
    params : dict or list of dict
        Parameter grid for ``GridSearchCV``; keys are *model*'s own
        parameter names (no pipeline prefix needed, even when ``scale=True``).
    X, y :
        Feature rows and class labels.
    cv : int, default 5
        Number of cross-validation folds.
    scale : bool, default False
        When True, standardize features inside each CV fold by wrapping
        *model* in a ``StandardScaler`` -> model pipeline. Grid keys are
        prefixed with ``model__`` automatically (and appear that way in the
        printed ``best_params_``).

    Returns
    -------
    GridSearchCV
        The fitted search object, so callers can reuse ``best_estimator_``
        or ``cv_results_``.
    """
    estimator, grid = model, params
    if scale:
        from sklearn.pipeline import Pipeline

        estimator = Pipeline([("scaler", StandardScaler()), ("model", model)])
        # Pipeline parameters are addressed as "<step>__<param>".
        sub_grids = [params] if isinstance(params, dict) else list(params)
        grid = [
            {f"model__{key}": values for key, values in sub_grid.items()}
            for sub_grid in sub_grids
        ]

    grid_search = GridSearchCV(estimator, grid, cv=cv)
    grid_search.fit(X, y)
    # Lead with the estimator's class name; otherwise consecutive result
    # lines cannot be matched to the model that produced them.
    print(
        type(model).__name__,
        grid_search.best_score_,
        "Parameters:",
        grid_search.best_params_,
    )
    return grid_search
# Estimators to tune. random_state is pinned on the stochastic ones so that
# repeated runs report the same cross-validation scores.
rfc = RandomForestClassifier(random_state=0)
dtc = DecisionTreeClassifier(random_state=0)
# "liblinear" supports both the "l1" and "l2" penalties searched in param_lr;
# the default "lbfgs" solver cannot fit "l1".
lr = LogisticRegression(solver="liblinear")
ab = AdaBoostClassifier(random_state=0)
gbc = GradientBoostingClassifier(random_state=0)
gnb = GaussianNB()
knn = KNeighborsClassifier()
# max_iter raised above the default of 200; the features are unscaled, so the
# optimizer may need more iterations to converge (hedged — tune if it warns).
mlp = MLPClassifier(max_iter=1000, random_state=0)
svc = SVC()
qda = QuadraticDiscriminantAnalysis()

# Run every search in a fixed order; each prints one result line.
for estimator, grid in (
    (rfc, param_rf),
    (dtc, param_dt),
    (lr, param_lr),
    (ab, param_ab),
    (gbc, param_gbc),
    (gnb, param_gnb),
    (knn, param_knn),
    (mlp, param_mlp),
    (svc, param_svc),
    (qda, param_qda),
):
    fn_grid_search(estimator, grid, X, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment