Skip to content

Instantly share code, notes, and snippets.

Created August 15, 2022 00:36
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cmparlettpelleriti/bd734748158d24e578f4daa9823f8c47 to your computer and use it in GitHub Desktop.
Save cmparlettpelleriti/bd734748158d24e578f4daa9823f8c47 to your computer and use it in GitHub Desktop.
from shiny import App, render, ui, reactive
from pathlib import Path
# Import modules for plot rendering
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# Import modeling packages
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
def plotDecisionBoundary(Xdat, ydat, mod, model_type, lim=3, num=100):
    """Draw a fitted classifier's decision regions under its training points.

    The model is evaluated on a regular ``num`` x ``num`` grid spanning
    ``[-lim, lim]`` on both axes. Predictions are drawn as a faint
    background and the observed points are drawn on top.

    Parameters
    ----------
    Xdat : pandas.DataFrame
        Two feature columns. The fixed ``[-lim, lim]`` window assumes the
        features were standardized (as ``modeldata`` does).
    ydat : pandas.Series
        Class labels aligned (by index) with ``Xdat``.
    mod : classifier
        Fitted estimator exposing ``predict`` for a DataFrame with
        ``Xdat``'s column names.
    model_type : str
        Model name used in the plot title.
    lim : float, optional
        Half-width of the square plotting window (default 3).
    num : int, optional
        Number of grid points per axis for the background (default 100).

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot.
    """
    col0, col1 = Xdat.columns[0], Xdat.columns[1]
    Xy = pd.concat([Xdat, ydat], axis=1)

    # Full mesh: each x0 value is paired with each x1 value exactly once.
    # The repeat/tile counts must equal the other axis' length, otherwise
    # points are duplicated (or combinations are missed).
    x0_range = np.linspace(-lim, lim, num=num)
    x1_range = np.linspace(-lim, lim, num=num)
    x_grid = pd.DataFrame({
        col0: np.repeat(x0_range, len(x1_range)),
        col1: np.tile(x1_range, len(x0_range)),
    })

    # Predict every background point before adding the label column.
    x_grid["p"] = mod.predict(x_grid)

    fig, ax = plt.subplots()
    ys = np.unique(ydat)
    base_colors = ["#264653", "#2a9d8f", "#e9c46a"]
    # Fall back to a generated palette when there are more classes than
    # hand-picked colours, instead of failing with IndexError.
    colors = (
        base_colors
        if len(ys) <= len(base_colors)
        else sns.color_palette("husl", len(ys))
    )
    palette = dict(zip(ys, colors))

    # Background: predicted class per grid point, small and translucent.
    sns.scatterplot(x=col0, y=col1, hue="p", ax=ax, data=x_grid,
                    alpha=0.1, s=5, legend=False, palette=palette)
    # Foreground: the actual observations coloured by true class.
    sns.scatterplot(x=col0, y=col1, hue=ydat, ax=ax, data=Xy,
                    palette=palette)
    ax.set(xlim=[-lim, lim], ylim=[-lim, lim],
           title=model_type + " Decision Boundary")
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    return fig
def modeldata(Xdat, ydat, modeltype, depth=10, k=5):
    """Standardize the features and fit the requested classifier.

    Parameters
    ----------
    Xdat : pandas.DataFrame
        Feature columns. NOTE: standardized *in place*: its columns are
        overwritten with z-scores, so the caller can plot the same scaled
        data the model was trained on. Pass a copy to keep raw values.
    ydat : pandas.Series
        Class labels aligned with ``Xdat``.
    modeltype : str
        One of "Logistic Regression", "Naive Bayes", "Decision Tree", "KNN".
    depth : int, optional
        ``max_depth`` for the decision tree (default 10).
    k : int, optional
        ``n_neighbors`` for KNN (default 5).

    Returns
    -------
    The fitted estimator.

    Raises
    ------
    ValueError
        If ``modeltype`` is not one of the supported names.
    """
    if modeltype == "Logistic Regression":
        mod = LogisticRegression()
    elif modeltype == "Naive Bayes":
        mod = GaussianNB()
    elif modeltype == "Decision Tree":
        mod = DecisionTreeClassifier(max_depth=depth)
    elif modeltype == "KNN":
        mod = KNeighborsClassifier(n_neighbors=k)
    else:
        # Previously fell through and failed later with an unbound `mod`.
        raise ValueError(f"Unknown model type: {modeltype!r}")

    z = StandardScaler()
    Xdat[Xdat.columns] = z.fit_transform(Xdat)
    mod.fit(Xdat, ydat)
    return mod
# Page layout: a row of controls, the trigger button, then a 2x2 grid of
# decision-boundary plots. Output ids must match the server's plot
# function names (plot_lr, plot_nb, plot_dt, plot_knn).
app_ui = ui.page_fluid(
    ui.row(
        ui.column(
            4,
            ui.input_select(
                "dataset",
                "Choose a Data Set:",
                ["Palmer Penguins", "Iris", "Diabetes"],
            ),
        ),
        ui.column(
            4,
            ui.input_slider("depth", "Max Depth", min=1, max=100, value=10),
        ),
        ui.column(
            4,
            ui.input_slider(
                "nneighbors", "Number of Neighbors", min=1, max=100, value=10
            ),
        ),
    ),
    ui.input_action_button("go", "Create Decision Boundary"),
    ui.row(
        ui.column(6, ui.output_plot("plot_lr")),
        ui.column(6, ui.output_plot("plot_nb")),
    ),
    ui.row(
        ui.column(6, ui.output_plot("plot_dt")),
        ui.column(6, ui.output_plot("plot_knn")),
    ),
)
def server(input, output, session):
    """Shiny server: load the chosen dataset and render four boundary plots.

    The selected CSV (read from this file's directory) is loaded into the
    reactive values ``X`` (two feature columns) and ``y`` (class labels).
    The four plots re-render only when the "go" button is pressed.
    """
    X = reactive.Value(pd.DataFrame())
    y = reactive.Value(pd.Series(dtype=object))

    # dataset label -> (csv filename, two feature columns, target column)
    datasets = {
        "Palmer Penguins": (
            "penguins.csv", ["bill_length_mm", "bill_depth_mm"], "species"
        ),
        "Diabetes": ("diabetes.csv", ["Glucose", "BloodPressure"], "Outcome"),
        "Iris": ("iris.csv", ["sepal_length", "sepal_width"], "species"),
    }

    @reactive.Effect
    def _():
        spec = datasets.get(input.dataset())
        if spec is None:
            return
        filename, features, target = spec
        df = pd.read_csv(Path(__file__).parent / filename)
        # Drop incomplete rows on a new frame (not inplace on a slice).
        df = df[features + [target]].dropna()
        X.set(df[features])
        # Labels were previously never stored, leaving models with no target.
        y.set(df[target])

    def _fit_and_plot(model_type, **model_kwargs):
        # Scale a private copy: modeldata standardizes its input in place,
        # and the shared reactive DataFrame must keep its raw values.
        X_scaled = X.get().copy()
        if X_scaled.empty:
            return None
        mod = modeldata(X_scaled, y.get(), model_type, **model_kwargs)
        return plotDecisionBoundary(X_scaled, y.get(), mod, model_type)

    @output
    @render.plot
    @reactive.event(input.go)
    def plot_lr():
        return _fit_and_plot("Logistic Regression")

    @output
    @render.plot
    @reactive.event(input.go)
    def plot_nb():
        return _fit_and_plot("Naive Bayes")

    @output
    @render.plot
    @reactive.event(input.go)
    def plot_dt():
        return _fit_and_plot("Decision Tree", depth=input.depth())

    @output
    @render.plot
    @reactive.event(input.go)
    def plot_knn():
        return _fit_and_plot("KNN", k=input.nneighbors())
# Assemble the Shiny application from the UI definition and server function.
# NOTE(review): debug=True is presumably meant for development only.
app = App(ui=app_ui, server=server, debug=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment