Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save grantmwilliams/4da13a2f6ac85f9c5e70571a07c43ded to your computer and use it in GitHub Desktop.
Save grantmwilliams/4da13a2f6ac85f9c5e70571a07c43ded to your computer and use it in GitHub Desktop.
""" This file uses a decision tree to classify the input data
and then loads the appropriate auxiliary dataset from the classification result
Uses SKlearn's decision tree implementation and as an example the
SKlearn iris dataset
"""
import sys
import numpy as np # only used to create our example datasets.
import pandas as pd # only used to create our example datasets.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
def create_example_datasets(iris):
    """Split the iris bunch into one CSV file per flower class.

    Writes ``datasets/setosa.csv``, ``datasets/versicolor.csv`` and
    ``datasets/virginica.csv``, each containing only the rows for that
    flower, with the feature names as column headers.

    Parameters
    ----------
    iris : mapping with "data", "target" and "feature_names" keys
        e.g. the bunch returned by ``sklearn.datasets.load_iris``.
    """
    import os

    data = iris["data"]
    categories = iris["target"]
    # pandas.to_csv does not create missing parent directories and would
    # raise FileNotFoundError, so make sure the output directory exists.
    os.makedirs("datasets", exist_ok=True)
    # Numeric class label -> output file name
    # (label 0 = setosa, 1 = versicolor, 2 = virginica).
    class_files = {0: "setosa.csv", 1: "versicolor.csv", 2: "virginica.csv"}
    for label, file_name in class_files.items():
        # DataFrame of only the rows for this flower, with named columns.
        subset = pd.DataFrame(
            data=data[categories == label, :], columns=iris["feature_names"]
        )
        subset.to_csv("datasets/" + file_name)
def load_dataset(prediction):
    """Return the auxiliary dataset matching *prediction*.

    Given a predicted flower name ("setosa", "versicolor" or
    "virginica"), read the corresponding CSV from the ``datasets``
    directory and return it as a DataFrame.
    """
    # Predicted class name -> CSV file holding that class's rows.
    csv_for_class = {
        "virginica": "virginica.csv",
        "setosa": "setosa.csv",
        "versicolor": "versicolor.csv",
    }
    # This is where a real application would plug in its own loading
    # logic; as an example we read back one of the CSV files produced
    # from the iris subsets.
    frame = pd.read_csv(f"datasets/{csv_for_class[prediction]}")
    # The CSVs round-trip the index as an "Unnamed: 0" column --
    # strip any such column before returning.
    wanted = ~frame.columns.str.contains("^Unnamed")
    return frame.loc[:, wanted]
# --- Example driver -------------------------------------------------
# Load the dataset we are going to use in the decision tree.
iris = load_iris()
# Use the iris dataset to create the per-class CSV files we want to
# load after the decision tree has been evaluated.
create_example_datasets(iris)
X = iris["data"]
y = iris["target"]
X_train, X_test, y_train, y_test = train_test_split(X, y)
model = DecisionTreeClassifier()
model.fit(X_train, y_train)
print("Model Accuracy: {}%".format(model.score(X_test, y_test) * 100))
y_predict = model.predict(X_test)
# Map the numeric predictions back to flower names.
predicted_names = [iris["target_names"][i] for i in y_predict]
# As an example, load the dataset for the first predicted name.
print("Loading Dataset for {}".format(predicted_names[0]))
output_dataset = load_dataset(predicted_names[0])
# Temporarily print all columns on the same line.
with pd.option_context("expand_frame_repr", False):
    # BUG FIX: the original line was missing its closing parenthesis,
    # which made the whole file a SyntaxError.
    print(output_dataset.head(5))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment