Skip to content

Instantly share code, notes, and snippets.

@mardani72
Last active October 11, 2020 16:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mardani72/24a4ce6f940dc1c8ca60766489b32b13 to your computer and use it in GitHub Desktop.
Save mardani72/24a4ce6f940dc1c8ca60766489b32b13 to your computer and use it in GitHub Desktop.
12_baseline_model
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import ExtraTreesClassifier
# Define the baseline classifiers, all with default hyperparameters.
# The estimators whose fitting is randomized (the tree-based models and
# gradient boosting) get a fixed random_state, matching the random_state=1
# used for the CV splitter. This makes repeated runs report identical scores.
log = LogisticRegression()
knn = KNeighborsClassifier()
dtree = DecisionTreeClassifier(random_state=1)
rtree = RandomForestClassifier(random_state=1)
svm = SVC()
nb = GaussianNB()
gbc = GradientBoostingClassifier(random_state=1)
etree = ExtraTreesClassifier(random_state=1)
# Define a function that uses a pipeline to apply the data transformation, fit the model, and cross-validate it
def baseline_model(model_name, X=None, y=None, cv=None, scoring='accuracy'):
    """Cross-validate an estimator behind a StandardScaler and print its mean score.

    Parameters
    ----------
    model_name : estimator
        An unfitted scikit-learn-style estimator instance (despite the name,
        not a string). It becomes the final ``'ml'`` step of the pipeline.
    X, y : array-like, optional
        Features and labels to evaluate on. They default to the module-level
        ``X_sm`` / ``y_sm``, which must be defined before calling.
    cv : int or CV splitter, optional
        Cross-validation strategy passed to ``cross_val_score``. Defaults to
        10-fold stratified CV repeated 3 times with ``random_state=1``.
    scoring : str, default 'accuracy'
        Scorer name passed to ``cross_val_score``.

    Returns
    -------
    float
        Mean of the cross-validated scores. The value is also printed.
    """
    if X is None:
        X = X_sm
    if y is None:
        y = y_sm
    if cv is None:
        cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)

    # Scaling lives inside the pipeline so it is re-fit on each training fold
    # and never sees the held-out fold.
    pipeline = Pipeline(steps=[('ss', StandardScaler()), ('ml', model_name)])

    # NOTE(review): if X_sm/y_sm were oversampled with SMOTE *before* this
    # split, synthetic samples built from held-out points leak into the
    # training folds and inflate the scores. Consider resampling inside the
    # CV pipeline instead (e.g. imblearn.pipeline.Pipeline). Verify how
    # X_sm/y_sm are produced.
    scores = cross_val_score(pipeline, X, y, scoring=scoring, cv=cv, n_jobs=-1)

    # cross_val_score returns an ndarray, so .mean() needs no separate import.
    mean_score = scores.mean()
    label = 'Accuracy' if scoring == 'accuracy' else scoring
    print(model_name, '%s: %.3f' % (label, mean_score))
    return mean_score
# Evaluate every baseline classifier with the same cross-validation pipeline,
# in a fixed order, printing one score line per model.
for _baseline_clf in (log, knn, dtree, rtree, svm, nb, gbc, etree):
    baseline_model(_baseline_clf)
#LogisticRegression() Accuracy: 0.623
#KNeighborsClassifier() Accuracy: 0.880
#DecisionTreeClassifier() Accuracy: 0.845
#RandomForestClassifier() Accuracy: 0.910
#SVC() Accuracy: 0.777
#GaussianNB() Accuracy: 0.357
#GradientBoostingClassifier() Accuracy: 0.832
#ExtraTreesClassifier() Accuracy: 0.934
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment