Skip to content

Instantly share code, notes, and snippets.

@samirak93
Last active May 22, 2020 04:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samirak93/403e22943120b6e640037bfcdd33f6fd to your computer and use it in GitHub Desktop.
Save samirak93/403e22943120b6e640037bfcdd33f6fd to your computer and use it in GitHub Desktop.
from bokeh.io import output_file, show, output_notebook, save
from bokeh.models import (
ColumnDataSource, LabelSet, FactorRange, BasicTicker, ColorBar, LinearColorMapper, PrintfTickFormatter)
from bokeh.plotting import figure
from bokeh.palettes import Category20, Viridis256
from bokeh.layouts import column
import pandas as pd
import os
# Send Bokeh `show()` output to the notebook, and record the parent of the
# current working directory; it is the base directory for the saved HTML file.
output_notebook()
path = os.path.split(os.getcwd())[0]
def _validate_plot_inputs(df, chart_type, x_axis, y_axis, hmap_fill):
    """Raise ``ValueError`` if the arguments cannot produce the requested chart.

    The DataFrame type is checked first, so a wrong ``df`` yields a clear
    ``ValueError`` rather than an ``AttributeError`` from ``df.columns``.
    """
    chart_types = ['bar', 'scatter', 'bar_stack', 'heatmap']
    if not isinstance(df, pd.DataFrame):
        raise ValueError(
            f'Input df should be a dataframe instead of {type(df)}.')
    if chart_type not in chart_types:
        raise ValueError(
            f'Given chart type {chart_type} is not one of {chart_types}. Please provide a valid value')
    if chart_type == 'heatmap' and hmap_fill is None:
        raise ValueError(
            'Chart is selected to be a Heatmap but "hmap_fill" is None. Please provide a valid column value.')

    df_cols = df.columns.tolist()
    if x_axis not in df_cols:
        raise ValueError(
            f'Given x-axis column {x_axis} is not a column in the dataframe provided. Please provide a valid column value.')

    if chart_type == 'bar_stack':
        # Check the container type before iterating it; a bare string would
        # otherwise be validated character by character.
        if not isinstance(y_axis, list):
            raise ValueError(
                f'For a stacked bar chart y-axis must be a list of labels but {type(y_axis)} was provided.')
        if not y_axis:
            raise ValueError(
                'For a stacked bar chart y-axis must contain at least one column.')
        # The stack colours are taken from Category20[20], i.e. 20 colours.
        if len(y_axis) > 20:
            raise ValueError(
                'The number of stacks is greater than 20. Please reduce the number of stacks.')
        for col in y_axis:
            if col not in df_cols:
                raise ValueError(
                    f'Given y-axis column {col} is not a column in the dataframe provided. Please provide a valid column value.')
    elif y_axis not in df_cols:
        raise ValueError(
            f'Given y-axis column {y_axis} is not a column in the dataframe provided. Please provide a valid column value.')

    if hmap_fill is not None and hmap_fill not in df_cols:
        raise ValueError(
            f'Given hmap_fill column {hmap_fill} is not a column in the dataframe provided. Please provide a valid column value.')


def _field_ref(column):
    """Return a hover-tooltip field reference for ``column``.

    The braced ``@{...}`` form lets Bokeh resolve column names that contain
    spaces or other non-word characters.
    """
    return '@{' + str(column) + '}'


def _build_tooltips(chart_type, x_axis, y_axis, x_label, y_label, hmap_fill):
    """Build the ``(label, field)`` hover-tooltip list for ``chart_type``."""
    tooltips = [(str(x_label), _field_ref(x_axis))]
    if chart_type == 'bar_stack':
        # One entry per stacked column, labelled with the column name itself.
        tooltips.extend((str(col), _field_ref(col)) for col in y_axis)
    else:
        tooltips.append((str(y_label), _field_ref(y_axis)))
    if chart_type == 'heatmap':
        tooltips.append((str(hmap_fill), _field_ref(hmap_fill)))
    return tooltips


def create_plots(df, chart_type: str, x_axis, y_axis: 'str | list[str]', x_label: str,
                 y_label: str, plot_title='Plot', hmap_fill=None, show_plot=False):
    """Build a Bokeh figure of the requested type from columns of ``df``.

    Args:
        df: Source ``pandas.DataFrame``; wrapped in a ``ColumnDataSource``.
        chart_type: One of ``'bar'``, ``'scatter'``, ``'bar_stack'``, ``'heatmap'``.
        x_axis: Column plotted on the x axis.
        y_axis: Column plotted on the y axis. For ``'bar_stack'``, a non-empty
            list of at most 20 columns to stack.
        x_label: X-axis title; also the tooltip label for ``x_axis``.
        y_label: Y-axis title; also the tooltip label for ``y_axis`` (not used
            as a tooltip label for ``'bar_stack'``).
        plot_title: Figure title, centred.
        hmap_fill: Column mapped to colour. Required for ``'heatmap'``.
        show_plot: If true, ``show`` the figure before returning it.

    Returns:
        The configured Bokeh figure.

    Raises:
        ValueError: If ``df`` is not a DataFrame, ``chart_type`` is unknown,
            or a referenced column is missing or malformed.
    """
    _validate_plot_inputs(df, chart_type, x_axis, y_axis, hmap_fill)

    source = ColumnDataSource(df)
    # Settings shared by every chart type. ``width``/``height`` are used
    # because ``plot_width``/``plot_height`` were removed in Bokeh 3.
    figure_kwargs = dict(
        width=750, height=550, tools=['hover', 'save'], title=plot_title,
        tooltips=_build_tooltips(chart_type, x_axis, y_axis,
                                 x_label, y_label, hmap_fill))

    if chart_type == 'bar':
        plot = figure(x_range=df[x_axis].values, **figure_kwargs)
        plot.vbar(x=x_axis, top=y_axis, width=0.7,
                  source=source, color='dodgerblue', fill_alpha=0.7)
        # Value labels drawn at each bar's top. ``render_mode`` is omitted:
        # 'canvas' was its default, and Bokeh 3 no longer accepts it.
        labels = LabelSet(x=x_axis, y=y_axis, text=y_axis, level='glyph',
                          x_offset=-9, y_offset=5, source=source,
                          text_font_size='13px')
        plot.add_layout(labels)
        plot.xgrid.grid_line_color = None
    elif chart_type == 'bar_stack':
        plot = figure(x_range=df[x_axis].values, **figure_kwargs)
        plot.vbar_stack(y_axis, x=x_axis, width=0.7, alpha=0.5,
                        color=Category20[20][:len(y_axis)], source=source,
                        legend_label=y_axis)
        plot.x_range.range_padding = 0.1
        plot.xgrid.grid_line_color = None
    elif chart_type == 'scatter':
        # Axis ranges are left to Bokeh's defaults.
        plot = figure(**figure_kwargs)
        plot.circle(x=x_axis, y=y_axis, size=10, source=source,
                    color='dodgerblue', alpha=0.5)
        plot.outline_line_color = None
        plot.xgrid.grid_line_color = None
        plot.ygrid.grid_line_color = None
    else:  # 'heatmap' is the only type left after validation.
        # X factors in reverse order of first appearance; Y factors in order.
        x_factors = list(reversed(df[x_axis].unique().tolist()))
        y_factors = df[y_axis].unique().tolist()
        mapper = LinearColorMapper(
            palette=Viridis256, low=df[hmap_fill].min(), high=df[hmap_fill].max())
        plot = figure(x_range=x_factors, y_range=y_factors,
                      x_axis_location="above", **figure_kwargs)
        plot.rect(y=y_axis, x=x_axis, width=1, height=1,
                  source=source,
                  fill_color={'field': hmap_fill, 'transform': mapper},
                  line_color=None)
        color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size="12px",
                             label_standoff=6, border_line_color=None, location=(0, 0))
        plot.add_layout(color_bar, 'right')
        plot.grid.grid_line_color = None
        plot.axis.axis_line_color = None
        plot.axis.major_tick_line_color = None
        plot.axis.major_label_standoff = 0
        plot.axis.major_label_text_font_size = "10px"

    plot.xaxis.axis_label = x_label
    plot.yaxis.axis_label = y_label
    plot.title.align = 'center'
    if show_plot:
        show(plot)
    return plot
def save_plots(plots, save_plot=False, filename=None):
    """Optionally write ``plots`` to an HTML file, then display them.

    Args:
        plots: Bokeh figure or layout to save and show.
        save_plot: If true, direct output to ``filename`` and ``save`` it
            before showing.
        filename: Destination HTML path used when ``save_plot`` is true.
            Defaults to ``Crosses/plots/plot.html`` under the module-level
            ``path`` directory.
    """
    if save_plot:
        if filename is None:
            filename = os.path.join(path, 'Crosses', 'plots', 'plot.html')
        # Create the target directory so ``save`` does not fail when it is
        # missing; a bare filename has no directory part, hence the '.' fallback.
        os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
        output_file(filename)
        save(plots)
    show(plots)
# Demo 1: fruit counts drawn as a simple bar chart.
fruits = ['Apples', 'Pears', 'Nectarines', 'Plums', 'Grapes', 'Strawberries']
counts = [5, 3, 4, 2, 4, 6]
bar_plot_df = pd.DataFrame(dict(fruits=fruits, counts=counts))
bar_plot = create_plots(
    bar_plot_df,
    chart_type='bar',
    x_axis='fruits',
    y_axis='counts',
    x_label='Fruits',
    y_label='Count',
    plot_title='Sample Bar Graph',
    show_plot=True,
)
# Demo 2: math vs. science scores as a scatter plot.
math = [96, 93, 44, 62, 54, 67]
science = [55, 43, 34, 22, 44, 56]
scatter_plot_df = pd.DataFrame(dict(math=math, science=science))
scatter_plot = create_plots(
    scatter_plot_df,
    chart_type='scatter',
    x_axis='math',
    y_axis='science',
    x_label='Math',
    y_label='Science',
    plot_title='Sample Scatter Graph',
    show_plot=True,
)
# Demo 3: per-year fruit counts, stacked into one bar per fruit. The strings in
# `years` are both the DataFrame column names and the stacked y columns, so the
# two cannot drift apart.
fruits = ['Apples', 'Pears', 'Nectarines', 'Plums', 'Grapes', 'Strawberries']
years = ["2015", "2016", "2017"]
stack_bar_df = pd.DataFrame(data={'fruits': fruits,
                                  years[0]: [2, 1, 4, 3, 2, 4],
                                  years[1]: [5, 3, 4, 2, 4, 6],
                                  years[2]: [3, 2, 4, 4, 5, 3]})
stack_bar_plot = create_plots(stack_bar_df, chart_type='bar_stack',
                              x_axis='fruits', y_axis=years,
                              x_label='Fruits', y_label='Count',
                              plot_title='Stacked Bar Graph', show_plot=True)
# Demo 4: pass counts between players as a heatmap
# (passer on x, receiver on y, colour = total passes).
player_from = ['Gerrard', 'Torres', 'Alonso', 'Alonso', 'Reina']
player_to = ['Torres', 'Gerrard', 'Reina', 'Gerrard', 'Torres']
total_passes = [120, 45, 55, 97, 104]
heat_map_df = pd.DataFrame(dict(player_from=player_from,
                                player_to=player_to,
                                total_passes=total_passes))
heat_map_plot = create_plots(
    heat_map_df,
    chart_type='heatmap',
    x_axis='player_from',
    y_axis='player_to',
    x_label='Pass From',
    y_label='Pass To',
    hmap_fill='total_passes',
    plot_title='Heat Map Graph',
    show_plot=True,
)
# Gather the four demo figures into one column layout and display it.
# Pass save_plot=True to also write the layout to an HTML file first.
plots = column(bar_plot, scatter_plot, stack_bar_plot, heat_map_plot)
save_plots(plots, save_plot=False)
@samirak93
Copy link
Author

Sample outputs:

Bar Graph

bokeh_plot

Scatter Graph

bokeh_plot (1)

Stacked Bar Graph

bokeh_plot (2)

Heat Map Graph

bokeh_plot (3)

Save Plots

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment