Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
from sklearn import datasets
from sklearn import metrics
from sklearn.naive_bayes import GaussianNB
def get_iris_dataset():
    """Load scikit-learn's bundled iris data set.

    :return: the dictionary-like object produced by ``datasets.load_iris()``;
        callers in this script read its ``data`` and ``target`` attributes.
    """
    iris = datasets.load_iris()
    return iris
def create_model(model_class, dataset, **model_params):
    """Instantiate ``model_class`` and fit it on ``dataset``.

    :param model_class: estimator class to instantiate (e.g. a naive Bayes
        class such as ``GaussianNB``); it must accept ``model_params`` as
        constructor keyword arguments and provide a ``fit(X, y)`` method.
    :param dataset: dictionary-like object exposing ``data`` and ``target``
        attributes, passed to ``fit`` as features and labels respectively.
    :param model_params: optional keyword arguments forwarded to the
        ``model_class`` constructor (e.g. hyperparameters). Omitting them
        reproduces the original zero-argument construction.
    :return: the fitted model instance.
    """
    mdl = model_class(**model_params)
    # Return the instance itself rather than fit()'s result, so classes whose
    # fit() does not return ``self`` still work.
    mdl.fit(dataset.data, dataset.target)
    return mdl
def print_classification_report(expected, predicted):
    """Print a titled classification report followed by a separator line.

    :param expected: ground-truth labels.
    :param predicted: labels produced by the model.
    """
    header = "Classification report"
    print(header)
    report_text = metrics.classification_report(expected, predicted)
    print(report_text)
    print_separator()
def print_separator(width=30, char="="):
    """Print a horizontal rule used to separate report sections.

    :param width: how many times ``char`` is repeated (default 30).
    :param char: string repeated to build the rule (default ``"="``).
    """
    print(char * width)
def print_confusion_matrix(expected, predicted):
    """Print a titled confusion matrix followed by a separator line.

    :param expected: ground-truth labels.
    :param predicted: labels produced by the model.
    """
    title = "Confusion matrix:"
    print(title)
    matrix = metrics.confusion_matrix(expected, predicted)
    print(matrix)
    print_separator()
# Fit a Gaussian naive Bayes model on the iris data and print its metrics.
# NOTE(review): create_model fits on ds.data / ds.target, and predictions are
# then made on that same ds.data, so the reports below measure fit on the
# training samples rather than performance on unseen data.
ds = get_iris_dataset()
model = create_model(GaussianNB, ds)
expected, predicted = ds.target, model.predict(ds.data)
for report in (print_classification_report, print_confusion_matrix):
    report(expected, predicted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.