Created
October 22, 2019 14:20
-
-
Save JohnDeJesus22/63d8750254efc862d1d3e27e0209280a to your computer and use it in GitHub Desktop.
Class to create categories for TeacherBoard web app feature.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json

import numpy as np
import plotly
import plotly.graph_objs as go
from sklearn.cluster import KMeans
class KmeansGrouper:
    '''
    Group the rows of a two-column dataset into categories with k-means and
    prepare the results (plot JSON, per-category dataframes, HTML tables)
    for display in the kmeans dashboard.
    '''
    # NOTE(review): KMeans may find fewer distinct clusters than requested when
    # the data has duplicate values (e.g. only 2 found when 4 were requested),
    # which raises a convergence warning. A default amount of clusters (maybe 2?)
    # may be worth adding. scatter_plot tolerates missing/non-contiguous labels.

    def __init__(self, clusters):
        '''
        :param clusters: number of clusters (categories) to request from KMeans
        '''
        self.clusters = clusters
        # Fixed random_state so the same data always yields the same grouping.
        self.kmeans_obj = KMeans(n_clusters=self.clusters, init='k-means++', max_iter=300, n_init=10, random_state=0)

    def scatter_plot(self, kmeans_data, groups):
        '''
        Create the plot for the kmeans grouping to display in the kmeans dashboard.

        One trace is built per category label actually present in ``groups``.
        No empty traces are emitted for labels KMeans never assigned, and no
        label is dropped when the labels are not exactly 1..k.

        :param kmeans_data: NumPy array whose columns 0 and 1 are plotted
                            (e.g. the ``data_np`` returned by kmeans_categories)
        :param groups: list or array of category labels, one per row of kmeans_data
        :return: JSON string of the plotly figure
        '''
        # Boolean masking in _scatterplot_maker needs an array, not a plain list.
        group_array = np.asarray(groups)
        # One trace per label that actually occurs, in ascending order.
        trace = [self._scatterplot_maker(label, kmeans_data, group_array)
                 for label in self._group_labels(group_array)]
        scatter_plot = dict(data=trace,
                            layout=dict(title='<b>Categories Plot</b>',
                                        xaxis=dict(title='Column 1'),
                                        yaxis=dict(title='Column 2'),
                                        hovermode='closest'
                                        )
                            )
        graphJSON = json.dumps(scatter_plot, cls=plotly.utils.PlotlyJSONEncoder)
        return graphJSON

    @staticmethod
    def _group_labels(groups):
        '''Return the distinct labels in ``groups`` as sorted plain Python values.'''
        return np.unique(np.asarray(groups)).tolist()

    def _scatterplot_maker(self, group_label, data_np, groups):
        '''
        Build the scatter trace for a single category.

        :param group_label: category label whose points are plotted
        :param data_np: NumPy array; columns 0 and 1 become x and y
        :param groups: NumPy array of labels aligned with the rows of data_np
        :return: plotly Scatter trace named after the label
        '''
        mask = groups == group_label
        trace = go.Scatter(
            x=data_np[mask, 0],
            y=data_np[mask, 1],
            name=str(group_label),
            mode='markers',
            marker=dict(size=12,
                        line=dict(
                            color='black',
                            width=1
                        ))
        )
        return trace

    def kmeans_categories(self, data):
        '''
        Apply kmeans to data with 2 columns selected by the user.

        Note: a 'Category' column is added to ``data`` in place; the same
        object is returned. ``data_np`` is taken before that column is added.

        :param data: dataframe from 2 columns entered by the user
        :return: tuple (data with 'Category' column, category array starting at 1,
                 NumPy values of the original columns)
        '''
        data_np = data.values
        # KMeans labels start at 0; shift so user-facing categories start at 1.
        categories = self.kmeans_obj.fit_predict(data_np) + 1
        data['Category'] = categories
        return data, categories, data_np

    def dataframe_constructors(self, df_with_groupings):
        '''
        Create a list of dataframes split by category.

        :param df_with_groupings: dataframe with a 'Category' column
        :return: one dataframe per distinct category, in ascending category order
        '''
        categories = sorted(df_with_groupings.Category.unique().tolist())
        dataframes = [df_with_groupings[df_with_groupings.Category == category] for category in categories]
        return dataframes

    def dfs_to_html(self, dataframe_list):
        '''
        Create HTML versions of dataframes for display.

        :param dataframe_list: list of dataframes from dataframe_constructors
        :return: list of HTML tables for the tables attribute of render_template
        '''
        # NOTE(review): the CSS class index is 0-based while categories start
        # at 1; kept as-is in case templates depend on it.
        tables = [df.to_html(classes=f'Category {i}') for i, df in enumerate(dataframe_list)]
        return tables
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment