Last active: May 22, 2020 04:46
-
-
Save samirak93/403e22943120b6e640037bfcdd33f6fd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import pandas as pd
from bokeh.io import output_file, output_notebook, save, show
from bokeh.layouts import column
from bokeh.models import (
    BasicTicker,
    ColorBar,
    ColumnDataSource,
    FactorRange,
    LabelSet,
    LinearColorMapper,
    PrintfTickFormatter,
)
from bokeh.palettes import Category20, Viridis256
from bokeh.plotting import figure

# NOTE(review): BasicTicker, FactorRange and PrintfTickFormatter are not
# referenced anywhere in the code visible here.

# Send Bokeh output inline to the notebook.
output_notebook()

# Parent directory of the current working directory; save_plots builds its
# default HTML output location underneath it.
path = os.path.split(os.getcwd())[0]
# Chart types accepted by create_plots (also echoed in its error message).
_CHART_TYPES = ['bar', 'scatter', 'bar_stack', 'heatmap']

# Category20 provides at most 20 distinct colours, one per stacked column.
_MAX_STACKS = 20


def _validate_plot_args(df, chart_type, x_axis, y_axis, hmap_fill):
    """Raise ValueError if the arguments given to create_plots are inconsistent.

    Every check runs before any Bokeh object is built, so bad input fails fast
    with a descriptive message.
    """
    # Check the type first: everything below relies on ``df.columns``.
    if not isinstance(df, pd.DataFrame):
        raise ValueError(
            f'Input df should be a dataframe instead of {type(df)}.')
    if chart_type not in _CHART_TYPES:
        raise ValueError(
            f'Given chart type {chart_type} is not one of {_CHART_TYPES}. Please provide a valid value')
    if chart_type == 'heatmap' and hmap_fill is None:
        raise ValueError(
            'Chart is selected to be a Heatmap but "hmap_fill" is None. Please provide a valid column value.')

    df_cols = df.columns.tolist()
    if x_axis not in df_cols:
        raise ValueError(
            f'Given x-axis column {x_axis} is not a column in the dataframe provided. Please provide a valid column value.')

    if chart_type == 'bar_stack':
        # Verify the container type before iterating: a plain string would
        # otherwise be walked character by character.
        if not isinstance(y_axis, list):
            raise ValueError(
                f'For a stacked bar chart y-axis must be a list of labels but {type(y_axis)} was provided.')
        if not y_axis:
            raise ValueError(
                'For a stacked bar chart y-axis must contain at least one column label.')
        if len(y_axis) > _MAX_STACKS:
            raise ValueError(
                f'The number of stacks is greater than {_MAX_STACKS}. Please reduce the number of stacks.')
        for col in y_axis:
            if col not in df_cols:
                raise ValueError(
                    f'Given y-axis column {col} is not a column in the dataframe provided. Please provide a valid column value.')
    elif y_axis not in df_cols:
        raise ValueError(
            f'Given y-axis column {y_axis} is not a column in the dataframe provided. Please provide a valid column value.')

    if hmap_fill is not None and hmap_fill not in df_cols:
        raise ValueError(
            f'Given hmap_fill column {hmap_fill} is not a column in the dataframe provided. Please provide a valid column value.')


def _tooltip_field(column_name):
    """Return a hover-tool field reference for *column_name*.

    The braced form ``@{...}`` also resolves column names that contain spaces
    or other non-word characters, which a bare ``@name`` does not.
    """
    return '@{' + str(column_name) + '}'


def _build_tooltips(chart_type, x_axis, y_axis, x_label, y_label, hmap_fill):
    """Return the (label, field) hover-tool pairs for the given chart type."""
    tooltips = [(str(x_label), _tooltip_field(x_axis))]
    if chart_type == 'bar_stack':
        # One entry per stacked column, labelled with the column name itself.
        tooltips.extend((str(col), _tooltip_field(col)) for col in y_axis)
    else:
        tooltips.append((str(y_label), _tooltip_field(y_axis)))
    if chart_type == 'heatmap':
        tooltips.append((str(hmap_fill), _tooltip_field(hmap_fill)))
    return tooltips


def _unique_factors(series):
    """Return the distinct values of *series* in first-seen order.

    Bokeh rejects a categorical (factor) range containing duplicate factors.
    """
    return series.unique().tolist()


def _new_figure(plot_title, tooltips, **kwargs):
    """Create the common 750x550 figure with hover and save tools.

    ``width``/``height`` are used instead of ``plot_width``/``plot_height``,
    which were removed in Bokeh 3.
    """
    return figure(width=750, height=550, tools=['hover', 'save'],
                  title=plot_title, tooltips=tooltips, **kwargs)


def _bar_figure(df, source, x_axis, y_axis, hmap_fill, plot_title, tooltips):
    """Vertical bars of *y_axis* per *x_axis* category, each labelled with its value."""
    plot = _new_figure(plot_title, tooltips, x_range=_unique_factors(df[x_axis]))
    plot.vbar(x=x_axis, top=y_axis, width=0.7,
              source=source, color='dodgerblue', fill_alpha=0.7)
    plot.add_layout(LabelSet(x=x_axis, y=y_axis, text=y_axis, level='glyph',
                             x_offset=-9, y_offset=5, source=source,
                             text_font_size='13px'))
    plot.xgrid.grid_line_color = None
    return plot


def _bar_stack_figure(df, source, x_axis, y_axis, hmap_fill, plot_title, tooltips):
    """Stacked vertical bars: one stack per column listed in *y_axis*."""
    plot = _new_figure(plot_title, tooltips, x_range=_unique_factors(df[x_axis]))
    plot.vbar_stack(y_axis, x=x_axis, width=0.7, alpha=0.5,
                    color=Category20[20][:len(y_axis)], source=source,
                    legend_label=y_axis)
    plot.x_range.range_padding = 0.1
    plot.xgrid.grid_line_color = None
    return plot


def _scatter_figure(df, source, x_axis, y_axis, hmap_fill, plot_title, tooltips):
    """Circle markers at (*x_axis*, *y_axis*) with auto-ranged axes."""
    plot = _new_figure(plot_title, tooltips)
    # scatter(marker='circle') replaces circle(size=...), deprecated in Bokeh 3.4.
    plot.scatter(x=x_axis, y=y_axis, marker='circle', size=10, source=source,
                 color='dodgerblue', alpha=0.5)
    plot.outline_line_color = None
    plot.xgrid.grid_line_color = None
    plot.ygrid.grid_line_color = None
    return plot


def _heatmap_figure(df, source, x_axis, y_axis, hmap_fill, plot_title, tooltips):
    """Grid of unit cells coloured by *hmap_fill*, with a colour bar on the right."""
    mapper = LinearColorMapper(palette=Viridis256,
                               low=df[hmap_fill].min(), high=df[hmap_fill].max())
    plot = _new_figure(plot_title, tooltips,
                       x_range=list(reversed(_unique_factors(df[x_axis]))),
                       y_range=_unique_factors(df[y_axis]),
                       x_axis_location="above")
    plot.rect(y=y_axis, x=x_axis, width=1, height=1,
              source=source,
              fill_color={'field': hmap_fill, 'transform': mapper},
              line_color=None)
    color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size="12px",
                         label_standoff=6, border_line_color=None, location=(0, 0))
    plot.add_layout(color_bar, 'right')
    plot.grid.grid_line_color = None
    plot.axis.axis_line_color = None
    plot.axis.major_tick_line_color = None
    plot.axis.major_label_standoff = 0
    plot.axis.major_label_text_font_size = "10px"
    return plot


def create_plots(df, chart_type: str, x_axis, y_axis: str, x_label: str,
                 y_label: str, plot_title='Plot', hmap_fill=None, show_plot=False):
    """Build a Bokeh bar, stacked-bar, scatter or heatmap figure from *df*.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data; every referenced column must exist in it.
    chart_type : str
        One of ``'bar'``, ``'scatter'``, ``'bar_stack'`` or ``'heatmap'``.
    x_axis : column label
        Column plotted on the x axis (used as categories for bar, bar_stack
        and heatmap).
    y_axis : column label, or list of column labels for ``'bar_stack'``
        Column plotted on the y axis; for ``'bar_stack'`` a non-empty list of
        at most 20 columns, each drawn as one stack.
    x_label, y_label : str
        Axis titles; also used as hover-tool labels.
    plot_title : str
        Title displayed centred above the plot.
    hmap_fill : column label, optional
        Column whose values colour the heatmap cells; required for ``'heatmap'``.
    show_plot : bool
        If true, ``show`` is called on the figure before it is returned.

    Returns
    -------
    The configured Bokeh figure.

    Raises
    ------
    ValueError
        If *df* is not a DataFrame, *chart_type* is unknown, a referenced
        column is missing, ``hmap_fill`` is missing for a heatmap, or the
        stacked-bar ``y_axis`` is not a non-empty list of at most 20 columns.
    """
    _validate_plot_args(df, chart_type, x_axis, y_axis, hmap_fill)

    source = ColumnDataSource(df)
    tooltips = _build_tooltips(chart_type, x_axis, y_axis, x_label, y_label, hmap_fill)
    builders = {
        'bar': _bar_figure,
        'bar_stack': _bar_stack_figure,
        'scatter': _scatter_figure,
        'heatmap': _heatmap_figure,
    }
    plot = builders[chart_type](df, source, x_axis, y_axis, hmap_fill,
                                plot_title, tooltips)

    # Styling shared by every chart type.
    plot.xaxis.axis_label = x_label
    plot.yaxis.axis_label = y_label
    plot.title.align = 'center'
    if show_plot:
        show(plot)
    return plot
def save_plots(plots, save_plot=False, filename=None):
    """Show *plots* and, optionally, first write them to a standalone HTML file.

    Parameters
    ----------
    plots : Bokeh model or layout
        What to display/save, e.g. a ``column`` of figures.
    save_plot : bool
        If true, the plots are written to *filename* before being shown.
    filename : str, optional
        Destination HTML path. Defaults to ``Crosses/plots/plot.html`` under
        the module-level ``path`` directory.

    Notes
    -----
    ``show`` is always called, whether or not the plots were saved.
    """
    if save_plot:
        if filename is None:
            # os.path.join instead of "+" so the separator is correct on every OS.
            filename = os.path.join(path, 'Crosses', 'plots', 'plot.html')
        # Ensure the target directory exists before Bokeh writes the file.
        os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
        output_file(filename)
        save(plots)
    show(plots)
# --- Example: simple bar chart of fruit counts ---
fruits = ['Apples', 'Pears', 'Nectarines', 'Plums', 'Grapes', 'Strawberries']
counts = [5, 3, 4, 2, 4, 6]
bar_plot_df = pd.DataFrame(dict(fruits=fruits, counts=counts))
bar_plot = create_plots(
    bar_plot_df,
    chart_type='bar',
    x_axis='fruits',
    y_axis='counts',
    x_label='Fruits',
    y_label='Count',
    plot_title='Sample Bar Graph',
    show_plot=True,
)
# --- Example: scatter plot of math vs. science scores ---
math = [96, 93, 44, 62, 54, 67]
science = [55, 43, 34, 22, 44, 56]
scatter_plot_df = pd.DataFrame(dict(math=math, science=science))
scatter_plot = create_plots(
    scatter_plot_df, 'scatter', 'math', 'science', 'Math', 'Science',
    plot_title='Sample Scatter Graph', show_plot=True,
)
# --- Example: stacked bar chart, one stack per year column ---
fruits = ['Apples', 'Pears', 'Nectarines', 'Plums', 'Grapes', 'Strawberries']
years = ["2015", "2016", "2017"]
# NOTE(review): `colors` is not passed to create_plots, which colours the
# stacks from the Category20 palette instead.
colors = ["#c9d9d3", "#718dbf", "#e84d60"]
stack_bar_df = pd.DataFrame({
    'fruits': fruits,
    # Pair each year label with its per-fruit counts; columns keep year order.
    **dict(zip(years, ([2, 1, 4, 3, 2, 4],
                       [5, 3, 4, 2, 4, 6],
                       [3, 2, 4, 4, 5, 3]))),
})
stack_bar_plot = create_plots(
    stack_bar_df, 'bar_stack', 'fruits', list(years), 'Fruits', 'Count',
    plot_title='Stacked Bar Graph', show_plot=True,
)
# --- Example: heatmap of completed passes between players ---
player_from = ['Gerrard', 'Torres', 'Alonso', 'Alonso', 'Reina']
player_to = ['Torres', 'Gerrard', 'Reina', 'Gerrard', 'Torres']
total_passes = [120, 45, 55, 97, 104]
heat_map_df = pd.DataFrame(dict(player_from=player_from,
                                player_to=player_to,
                                total_passes=total_passes))
heat_map_plot = create_plots(
    heat_map_df, 'heatmap', 'player_from', 'player_to', 'Pass From', 'Pass To',
    plot_title='Heat Map Graph', hmap_fill='total_passes', show_plot=True,
)
# Stack every example figure vertically and hand the layout to save_plots;
# with save_plot=False it is only displayed, not written to disk.
plots = column(bar_plot, scatter_plot, stack_bar_plot, heat_map_plot)
save_plots(plots, save_plot=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Sample outputs:
Bar Graph
Scatter Graph
Stacked Bar Graph
Heat Map Graph
Save Plots