Created May 11, 2023 02:44
Save thuliumsystems/db562b9fdb2efbbd55d3ae765059704d to your computer and use it in GitHub Desktop.
GridSearchCV
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.tree import DecisionTreeClassifier | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.ensemble import ( | |
RandomForestClassifier, | |
AdaBoostClassifier, | |
GradientBoostingClassifier, | |
) | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.naive_bayes import GaussianNB | |
from sklearn.neighbors import KNeighborsClassifier | |
from sklearn.neural_network import MLPClassifier | |
from sklearn.svm import SVC | |
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis | |
from sklearn.model_selection import GridSearchCV | |
# Labelled samples, one row per observation.
#
# Each row holds eleven integers: the first ten are feature values (all
# small non-negative integers in this table) and the eleventh is the class
# label, 0 or 1.  The X / y construction that follows uses every column but
# the last as a feature and the last column as the target.
#
# NOTE(review): several rows occur more than once -- e.g. the first row
# appears four times in total.  Duplicated samples can end up in both the
# training and the validation fold during cross-validation and inflate the
# reported scores; confirm whether the duplicates are intentional.
# fmt: off
training = [
    [10, 10, 21,  3,  4,  1,  1,  0,  1,  0, 1],
    [21, 16, 12, 26,  6,  0,  0,  0,  0,  0, 1],
    [ 2,  5,  2,  3,  3,  3,  2,  3,  2,  4, 1],
    [13,  6,  3,  4,  4,  1,  1,  0,  0,  1, 1],
    [ 2,  3,  4,  2,  5,  2,  1,  3,  2,  1, 1],
    [ 2,  3,  2,  3,  4,  2,  4,  3,  2,  2, 1],
    [ 1,  0,  1,  1,  1,  4, 33,  2,  9,  4, 0],
    [ 6,  4,  3,  4,  2,  6,  2,  4,  1,  4, 1],
    [29, 12,  3,  6,  6,  2,  1,  0,  0,  1, 1],
    [ 0,  0,  0,  0,  0, 10, 27,  8,  8, 13, 0],
    [21,  5, 11, 12,  9,  0,  0,  0,  0,  0, 1],
    [ 5,  2,  2,  2,  2,  4,  5,  3,  4,  2, 1],
    [14, 20,  6,  8,  8,  0,  0,  0,  0,  0, 1],
    [ 2,  3,  2,  1,  1,  3,  7,  9,  3,  3, 0],
    [ 5,  2,  5,  2,  4,  2,  3,  3,  2,  1, 1],
    [ 8,  6,  4,  3,  2,  3,  4,  2,  1,  3, 1],
    [ 5,  2,  0,  0,  1,  7,  5,  9,  5,  2, 0],
    [ 6, 19,  2,  3,  4,  2,  0,  1,  1,  0, 1],
    [ 0,  0,  0,  0,  0,  6, 16,  6, 19, 22, 0],
    [17,  5, 28, 23,  4,  0,  0,  0,  0,  0, 1],
    [ 4, 15,  2,  7, 11,  0,  0,  0,  0,  0, 1],
    [ 3,  5,  4,  3,  8,  1,  2,  1,  2,  3, 1],
    [ 1,  1,  0,  1,  0,  7,  7,  5, 28,  6, 0],
    [ 5,  5,  3,  3,  8,  1,  1,  2,  1,  1, 1],
    [ 2,  1,  1,  1,  1,  9,  7,  8,  2,  6, 0],
    [ 7,  5,  5,  2,  4,  2,  2,  1,  1,  1, 1],
    [ 3, 23,  3,  5,  6,  2,  1,  2,  2,  1, 1],
    [ 1,  1,  1,  0,  1,  8,  5,  3,  2,  6, 0],
    [11,  9,  6,  1,  3,  0,  2,  1,  1,  2, 1],
    [11,  3,  8,  2,  5,  1,  1,  2,  1,  1, 1],
    [ 8, 11,  8,  5, 10,  1,  0,  0,  1,  0, 1],
    [16, 19,  6,  8,  2,  0,  0,  0,  0,  1, 1],
    [ 7,  3,  5,  3,  5,  5,  4,  1,  1,  2, 1],
    [ 1,  1,  1,  0,  1, 24,  5,  5,  5,  4, 0],
    [ 1,  0,  1,  0,  1, 12,  8, 10,  7,  4, 0],
    [ 6, 16,  5, 16,  7,  1,  0,  1,  0,  0, 1],
    [ 4,  4,  2,  2,  3,  5,  7,  1,  1,  2, 0],
    [ 4,  8,  3,  3,  5,  2,  1,  1,  1,  3, 1],
    [ 9, 10,  5,  7,  5,  1,  1,  1,  0,  0, 1],
    [ 4,  1,  1,  1,  2, 22,  4, 19,  4,  2, 0],
    [ 0,  0,  0,  0,  1, 13,  9, 10,  4, 11, 0],
    [14, 14,  9,  4,  4,  1,  1,  0,  0,  0, 1],
    [21, 24,  4,  8,  5,  0,  0,  0,  0,  0, 1],
    [ 0,  0,  0,  0,  0, 13,  6,  6, 34,  4, 0],
    [ 7,  5,  3,  3,  1,  9,  2,  1,  3,  4, 1],
    [ 0,  0,  0,  0,  0,  6, 10,  4, 13, 30, 0],
    [ 5,  8, 10,  9,  3,  1,  0,  1,  0,  1, 1],
    [ 7,  6,  5,  4,  2,  1,  3,  2,  1,  1, 1],
    [ 1,  1,  1,  1,  1,  5,  5,  7,  5,  5, 0],
    [ 1,  1,  2,  1,  2,  7,  4,  4,  3,  2, 0],
    [ 1,  0,  0,  0,  0, 31, 11,  6,  5,  7, 0],
    [ 1,  1,  3,  1,  1,  2,  4, 10,  4,  3, 0],
    [ 1,  2,  1,  1,  3,  5,  8,  9,  3,  7, 0],
    [ 7,  4,  4,  7,  3,  1,  2,  3,  2,  1, 1],
    [ 0,  1,  0,  1,  0, 12,  7, 22,  4, 16, 0],
    [ 1,  4,  3,  1,  1,  4,  5, 11,  2,  7, 0],
    [ 3,  2,  2,  3,  9,  7,  8,  1,  1,  2, 0],
    [ 3,  3,  1,  2,  2,  4,  7,  4,  3,  5, 0],
    [ 6,  7, 29,  3,  6,  0,  1,  0,  1,  1, 1],
    [ 0,  0,  1,  0,  0,  3, 10, 12, 10, 10, 0],
    [ 1,  1,  1,  2,  1,  8,  4,  5,  5,  4, 0],
    [ 0,  1,  0,  1,  0,  4, 22,  7,  4,  3, 0],
    [ 5,  3,  4,  5,  3,  3,  1,  2,  2,  1, 1],
    [ 1,  0,  1,  1,  0,  6,  8, 13,  6,  4, 0],
    [ 2,  1,  1,  0,  0, 18,  6, 38,  5,  8, 0],
    [ 1,  0,  3,  2,  2,  5,  8,  4,  3,  4, 0],
    [ 5, 18,  6,  5,  5,  0,  1,  0,  0,  1, 1],
    [ 0,  1,  2,  1,  3,  4,  6, 12,  2, 13, 0],
    [ 0,  1,  1,  1,  0,  7, 15,  6,  4,  7, 0],
    [ 0,  0,  1,  1,  1,  5, 18,  7,  3,  6, 0],
    [ 6,  7,  5,  9,  4,  1,  1,  0,  1,  1, 1],
    [ 0,  3,  0,  0,  0,  8,  6,  7, 19,  2, 0],
    [ 0,  1,  0,  0,  0,  7, 13, 11, 13,  6, 0],
    [ 0,  0,  1,  0,  0,  7,  5, 19,  8, 10, 0],
    [ 0,  1,  2,  0,  0, 19, 28, 10,  7,  5, 0],
    [17, 10, 10,  9,  5,  1,  0,  0,  0,  0, 1],
    [20,  9,  7,  8, 14,  0,  0,  0,  0,  0, 1],
    [ 0,  2,  0,  1,  0, 14,  4, 40,  4,  2, 0],
    [ 0,  0,  0,  0,  0,  5, 10,  6, 11, 11, 0],
    [ 2,  3,  3,  5,  3,  4,  5,  3,  2,  1, 0],
    [17,  3,  3,  4, 12,  2,  1,  0,  1,  1, 1],
    [ 7,  3,  6,  5,  4,  5,  2,  1,  1,  0, 1],
    [ 3,  3,  1,  1,  1,  4,  1,  6,  2,  3, 0],
    [ 4,  5,  3,  2,  4,  1,  4,  1,  1,  0, 1],
    [ 5,  7,  5,  2, 10,  1,  1,  1,  1,  2, 1],
    [ 1,  1,  0,  0,  1,  7,  5, 13, 10,  6, 0],
    [ 1,  0,  0,  1,  0, 21,  8,  6,  3,  9, 0],
    [ 8,  3,  1,  3,  5,  2,  3,  0,  2,  2, 1],
    [ 4,  1,  0,  1,  2, 20,  6,  2,  5,  3, 0],
    [ 1,  1,  1,  0,  0,  4, 11,  6,  4,  9, 0],
    [ 9,  3,  8,  2,  2,  2,  3,  0,  1,  1, 1],
    [ 7,  5,  6,  4,  2,  4,  3,  1,  1,  1, 1],
    [10,  6,  4,  6,  7,  0,  2,  1,  0,  1, 1],
    [ 4,  6,  5,  8,  8,  1,  0,  1,  1,  0, 1],
    [ 3,  4,  4,  3,  7,  2,  3,  1,  1,  2, 1],
    [ 6,  9, 10,  4,  8,  1,  1,  1,  2,  1, 1],
    [ 5,  9,  4,  3,  3,  1,  1,  0,  0,  1, 1],
    [ 9,  1,  2,  2,  2,  5,  5,  3,  3,  4, 0],
    [ 0,  0,  2,  1,  0,  9,  6,  6,  9,  7, 0],
    [14, 11,  7,  4,  2,  2,  1,  1,  1,  1, 1],
    [10, 10, 21,  3,  4,  1,  1,  0,  1,  0, 1],
    [ 7,  6,  7,  4, 22,  0,  1,  0,  0,  1, 1],
    [10,  2,  8,  2,  2,  2,  3,  1,  4,  2, 1],
    [11, 32, 15,  8,  5,  0,  0,  0,  0,  0, 1],
    [ 1,  1,  1,  1,  0,  3, 14,  7,  6,  5, 0],
    [ 6, 19,  8, 19, 16,  0,  0,  0,  0,  0, 1],
    [ 2,  3,  1,  0,  1, 13,  6,  5,  3,  2, 0],
    [17, 17, 36,  3,  3,  0,  0,  0,  1,  1, 1],
    [ 1,  3,  1,  2,  1, 16,  4,  4,  1,  5, 0],
    [ 9,  3, 21,  7,  3,  1,  0,  1,  1,  0, 1],
    [17, 18,  3, 19,  5,  1,  0,  0,  0,  0, 1],
    [ 1,  3,  1,  1,  2, 11,  3,  4,  3,  4, 0],
    [ 6,  3,  6,  3,  2,  1,  2,  3,  2,  3, 1],
    [ 2,  2,  2,  2,  2,  4,  3,  5,  3,  6, 0],
    [ 6, 14, 11,  4, 17,  0,  0,  1,  0,  0, 1],
    [10, 10, 21,  3,  4,  1,  1,  0,  1,  0, 1],
    [25,  5,  3,  3,  3,  1,  1,  0,  1,  2, 1],
    [ 6, 31,  6,  5,  8,  0,  1,  0,  0,  0, 1],
    [ 4,  3,  3,  2,  1,  3,  4,  2,  2,  4, 1],
    [ 1,  1,  1,  1,  0,  5,  5, 13, 13, 10, 0],
    [18,  9, 18,  5, 19,  0,  0,  0,  0,  0, 1],
    [ 0,  1,  2,  0,  0,  4, 14,  5,  3,  6, 0],
    [ 5,  6,  2,  3,  2,  2,  2,  2,  2,  1, 1],
    [23,  5,  2,  3,  4,  4,  2,  1,  2,  1, 1],
    [ 7,  5,  8,  8,  7,  2,  1,  1,  0,  0, 1],
    [ 1,  3,  1,  2,  2, 10,  7,  4,  3,  4, 0],
    [ 3,  3,  2,  2,  2,  2,  6,  4,  2,  3, 1],
    [10, 10, 21,  3,  4,  1,  1,  0,  1,  0, 1],
    [15,  6, 11,  7,  5,  0,  1,  0,  1,  0, 1],
    [11,  3,  8,  2,  5,  1,  1,  2,  1,  1, 1],
    [11,  5, 11,  3,  3,  1,  2,  1,  1,  1, 1],
    [ 8, 29,  6,  9,  7,  1,  0,  0,  0,  1, 1],
    [ 6, 16,  5, 16,  7,  1,  0,  1,  0,  0, 1],
    [ 3,  2,  8,  2,  1,  3,  3,  3,  1,  1, 1],
    [21, 24,  4,  8,  5,  0,  0,  0,  0,  0, 1],
    [ 0,  1,  1,  1,  1,  6,  8,  5, 17,  4, 0],
    [19, 16, 12,  4,  4,  0,  0,  1,  1,  0, 1],
    [ 3,  1,  2,  1,  1, 10,  7,  8,  3, 10, 0],
    [ 1,  0,  0,  0,  0, 31, 11,  6,  5,  7, 0],
    [16,  7, 18,  8,  2,  0,  1,  1,  0,  0, 1],
    [ 0,  1,  0,  0,  0,  7, 13, 11, 13,  6, 0],
    [ 0,  0,  1,  0,  0,  7,  5, 19,  8, 10, 0],
    [17, 10, 10,  9,  5,  1,  0,  0,  0,  0, 1],
    [ 2,  1,  0,  1,  1, 14, 10,  6,  5,  5, 0],
    [12,  7,  4,  5,  6,  0,  1,  1,  0,  0, 1],
    [ 1,  1,  1,  0,  0,  8, 13,  9,  9,  6, 0],
    [ 1,  2,  1,  1,  1, 18,  7,  9,  3,  1, 0],
    [ 7,  2,  3,  4,  2,  3,  5,  3,  3,  2, 1],
    [ 2,  3,  3,  5,  3,  4,  5,  3,  2,  1, 0],
    [10,  5,  7,  3,  3,  1,  2,  1,  1,  1, 1],
    [ 8,  9,  5,  5,  2,  2,  2,  3,  0,  0, 1],
    [ 4,  3, 10,  3,  1,  1,  3,  3,  4,  1, 1],
    [ 7,  3,  8,  3,  8,  1,  1,  1,  2,  1, 1],
    [ 1,  3,  4,  1,  3,  2,  2,  5,  2,  1, 1],
    [ 1,  1,  0,  0,  1,  7,  5, 13, 10,  6, 0],
    [ 6,  5,  4,  2,  2,  1,  4,  2,  1,  1, 1],
    [ 1,  3,  1,  2,  1, 16,  4,  4,  1,  5, 0],
    [ 4,  6,  4,  7,  3,  2,  1,  2,  2,  1, 1],
    [ 4,  4, 16,  5, 16,  3,  0,  1,  0,  1, 1],
    [ 5, 10, 14, 16,  3,  0,  0,  0,  0,  0, 1],
    [23,  5,  2,  3,  4,  4,  2,  1,  2,  1, 1],
    [ 2,  1,  2,  1,  3,  2,  5,  5,  3,  3, 0],
    [ 5,  5,  8,  4,  2,  3,  3,  1,  1,  2, 1],
    [ 3,  5,  3,  2,  4,  2,  3,  3,  0,  2, 1],
    [ 6, 19,  4,  4,  8,  0,  1,  0,  0,  0, 1],
    [ 1,  2,  0,  1,  1, 19,  3,  6,  9,  3, 0],
    [ 2,  2,  2,  0,  1,  4, 10,  9,  3,  5, 0],
    [ 2,  0,  0,  0,  0, 13,  6, 23,  9, 11, 0],
    [ 5, 34, 14, 14,  4,  0,  0,  1,  1,  0, 1],
    [19,  2,  5,  2, 21,  0,  6,  1,  1,  1, 1],
    [ 0,  0,  0,  0,  0, 21, 18, 20,  6,  9, 0],
    [13, 16, 22,  5, 13,  0,  0,  0,  0,  0, 1],
    [ 6,  1,  1,  1,  0,  5, 12,  6,  5,  2, 0],
    [ 5, 12,  4, 10,  4,  2,  1,  1,  1,  1, 1],
    [17, 18,  6,  8, 10,  1,  0,  0,  0,  0, 1],
    [ 9, 15, 10,  2,  4,  2,  1,  1,  2,  3, 1],
    [ 1,  0,  0,  0,  0, 16, 19, 10,  4,  6, 0],
    [ 6,  5,  3,  1,  3,  3,  2,  2,  1,  1, 1],
    [ 6,  4,  5,  8,  8,  1,  1,  0,  0,  0, 1],
    [ 1,  1,  1,  1,  1,  4,  5, 22,  3,  4, 0],
    [ 2,  2,  1,  0,  1,  7, 11,  6,  2,  3, 0],
    [ 6, 10,  3,  2,  5,  1,  4,  0,  1,  0, 1],
    [ 2,  2,  5,  2,  1,  4,  4,  3,  4,  2, 1],
    [ 2,  1,  0,  1,  1,  5, 12,  7,  7,  4, 0],
    [ 8, 22,  4,  4,  3,  1,  1,  0,  1,  1, 1],
    [ 0,  0,  0,  0,  0, 26, 15, 15,  7,  9, 0],
    [10,  5,  3,  3,  3,  2,  2,  3,  1,  1, 1],
    [ 0,  0,  0,  0,  0, 21,  9, 21, 13,  2, 0],
    [13,  4,  7,  7,  5,  1,  1,  1,  1,  0, 1],
    [12, 25,  6,  2,  6,  5,  1,  0,  0,  1, 1],
    [17,  4,  1,  1,  1,  4,  4,  8,  1,  4, 0],
    [ 0,  0,  0,  0,  0, 10, 22, 19,  5,  9, 0],
    [ 2,  5,  4,  2,  2,  9,  2,  2,  1,  1, 0],
]
# fmt: on
# Split every row into its feature vector (all columns but the last) and its
# class label (the last column).
X = [row[:-1] for row in training]
y = [row[-1] for row in training]

# Standardize the features only.  Fitting the scaler on the full rows would
# also rescale the label column, which must stay as the raw class value.
# NOTE: ``scaled_data`` is not consumed by the grid searches in this script;
# they train on the unscaled ``X``.  If scaling is wanted there, put
# StandardScaler inside a Pipeline so it is refit on each cross-validation
# training fold instead of on the whole dataset (which leaks statistics of
# the validation folds into training).
scaler = StandardScaler()
scaled_data = scaler.fit_transform(X)
# Hyper-parameter grids, one per estimator.  Each dict maps an estimator
# constructor argument to the candidate values GridSearchCV should try; the
# search fits every combination of the listed values.
param_rf = {
    "n_estimators": [50, 100, 200],
    "max_depth": [None, 5, 10],
    "min_samples_split": [2, 5, 10],
}
param_dt = {
    "max_depth": [None, 5, 10],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 4],
}
# LogisticRegression's default solver ("lbfgs") does not support
# penalty="l1", so every l1 candidate would fail to fit.  "liblinear"
# supports both l1 and l2, making the whole grid usable.
param_lr = {
    "penalty": ["l1", "l2"],
    "C": [0.001, 0.01, 0.1, 1, 10],
    "solver": ["liblinear"],
}
param_ab = {"n_estimators": [50, 100, 200], "learning_rate": [0.001, 0.01, 0.1, 1, 10]}
param_gbc = {"n_estimators": [50, 100, 200], "learning_rate": [0.001, 0.01, 0.1, 1, 10]}
param_gnb = {"var_smoothing": [1e-9, 1e-8, 1e-7, 1e-6, 1e-5]}
param_knn = {"n_neighbors": [3, 5, 7, 9, 11]}
param_mlp = {
    "hidden_layer_sizes": [(10,), (50,), (100,)],
    "activation": ["tanh", "relu"],
    "alpha": [0.0001, 0.001, 0.01],
}
# NOTE: SVC ignores ``gamma`` when kernel="linear", so the four gamma values
# yield identical linear models -- those candidates are redundant work.
param_svc = {
    "C": [0.1, 1, 10, 100],
    "gamma": [0.001, 0.01, 0.1, 1],
    "kernel": ["linear", "rbf"],
}
param_qda = {"reg_param": [0.1, 0.5, 1.0]}
def fn_grid_search(model, params, X, y, cv=5):
    """Run a cross-validated grid search for ``model`` and print the winner.

    Parameters
    ----------
    model : estimator
        Unfitted scikit-learn estimator whose hyper-parameters are tuned.
    params : dict or list of dict
        Parameter grid handed to ``GridSearchCV`` as ``param_grid``.
    X : array-like
        Feature matrix, one row per sample.
    y : array-like
        Target labels aligned with ``X``.
    cv : int, default 5
        Cross-validation setting forwarded to ``GridSearchCV``.

    Returns
    -------
    GridSearchCV
        The fitted search object, so callers can reuse ``best_estimator_``,
        ``best_params_`` or ``cv_results_`` instead of re-running the search.
    """
    grid_search = GridSearchCV(model, params, cv=cv)
    grid_search.fit(X, y)
    # Prefix the line with the estimator's class name; otherwise the output
    # of several consecutive searches cannot be told apart.
    print(
        type(model).__name__,
        grid_search.best_score_,
        "Parameters:",
        grid_search.best_params_,
    )
    return grid_search
# Seed shared by every estimator that accepts ``random_state``.  Without it
# the randomized estimators below give different scores (and possibly a
# different "best" parameter set) on every run of the script.
RANDOM_STATE = 0

rfc = RandomForestClassifier(random_state=RANDOM_STATE)
dtc = DecisionTreeClassifier(random_state=RANDOM_STATE)
lr = LogisticRegression(random_state=RANDOM_STATE)
ab = AdaBoostClassifier(random_state=RANDOM_STATE)
gbc = GradientBoostingClassifier(random_state=RANDOM_STATE)
# The remaining estimators are constructed with library defaults.
gnb = GaussianNB()
knn = KNeighborsClassifier()
mlp = MLPClassifier(random_state=RANDOM_STATE)
svc = SVC()
qda = QuadraticDiscriminantAnalysis()
# (estimator, parameter grid) pairs to tune, in execution order.  Each pair
# is handed to fn_grid_search, which prints that search's best score and
# parameters.  The LogisticRegression and MLPClassifier searches are
# currently disabled.
searches = (
    (rfc, param_rf),
    (dtc, param_dt),
    # (lr, param_lr),
    (ab, param_ab),
    (gbc, param_gbc),
    (gnb, param_gnb),
    (knn, param_knn),
    # (mlp, param_mlp),
    (svc, param_svc),
    (qda, param_qda),
)
for estimator, grid in searches:
    fn_grid_search(estimator, grid, X, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.