Skip to content

Instantly share code, notes, and snippets.

@schaunwheeler
Last active May 8, 2018 17:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save schaunwheeler/9a98d8bee5f039e9872c76fb24a6e69c to your computer and use it in GitHub Desktop.
Save schaunwheeler/9a98d8bee5f039e9872c76fb24a6e69c to your computer and use it in GitHub Desktop.
Example of workflow automation using Bokeh. Copyright (c) 2018 Valassis Digital under the terms of the BSD 3-Clause license.
"""
Code described in Medium post https://medium.com/@schaun.wheeler/codify-your-workflow-377f5f8bf4c3
Copyright (c) 2018 Valassis Digital under the terms of the BSD 3-Clause license.
"""
from math import sin, pi, log, floor
from shapely.wkt import loads, dumps
from shapely.geometry import Polygon, Point, MultiPolygon, LineString, MultiLineString, MultiPoint, GeometryCollection
from bokeh.io import output_notebook, show, output_file, reset_output, save
from bokeh.models import GMapPlot, GMapOptions, ColumnDataSource, Range1d
from bokeh.models.tools import HoverTool, WheelZoomTool, ResetTool, PanTool
from bokeh.models.renderers import GlyphRenderer
from bokeh.resources import CDN
from os import getcwd
from pandas import Series, DataFrame
from numpy import mean
from geohash2 import decode_exactly, decode, encode
class geoPlotter(object):
    """
    Wrap the boilerplate needed to use Bokeh to plot shapes on top of Google Maps images.

    Inputs (geohashes, well-known text, shapely geometries, or longitude/latitude pairs) are
    converted to x/y coordinate lists, stored as named Bokeh ColumnDataSources, and drawn as
    layers on a GMapPlot whose center, zoom and height are estimated from the data bounds.

    Example usage:
        from bokeh.models import Patches, Circle, Text
        geohashes = [
            'dnrgrfm', 'dnrgrf3', 'dnrgrf7', 'dnrgrf5', 'dnrgrfk', 'dnrgrf1',
            'dnrgrfh', 'dnrgrf2', 'dnrgrcu', 'dnrgrfj', 'dnrgrcv', 'dnrgrf4'
        ]
        API_KEY = '...'  # get one at developers.google.com/maps/documentation/javascript/get-api-key
        gp = geoPlotter(api_key=API_KEY)
        gp.add_source(geohashes, label='a', order=range(len(geohashes)))
        gp.prepare_plot(plot_width=700)
        gp.add_layer('a', Patches, color='yellow', alpha=0.5)
        gp.add_layer('a', Circle, color='blue', size=20, alpha=0.75)
        gp.add_layer('a', Text, text='order', text_color='white', text_font_size='10pt',
                     text_align='center', text_baseline='middle')
        gp.render_plot('auto')
    """

    def __init__(self, api_key):
        """
        Store the Google Maps API key and initialize empty plotting state.

        Parameters:
            api_key (str): Google Maps API key, obtainable at
                developers.google.com/maps/documentation/javascript/get-api-key
        """
        self.api_key = api_key
        self.sources = dict()  # label -> ColumnDataSource built by `add_source`
        # Running coordinate bounds over every source added so far (None until the first one).
        self.xmin = None
        self.ymin = None
        self.xmax = None
        self.ymax = None
        self.plot = None  # GMapPlot created by `prepare_plot`

    @staticmethod
    def _validate_geohash(value):
        """Return True if `value` is a non-empty string of at most 12 geohash base32 characters."""
        if not value or len(value) > 12:
            return False
        base32 = '0123456789bcdefghjkmnpqrstuvwxyz'
        return all(v in base32 for v in value)

    @staticmethod
    def _validate_wellknowntext(value):
        """
        Return True if `value` looks like well-known text, i.e. starts (after leading
        whitespace, case-insensitively) with a WKT geometry keyword.
        """
        # 'GEOMETRY' also covers 'GEOMETRYCOLLECTION'.
        valid_types = ('GEOMETRY', 'POINT', 'MULTIPOINT', 'LINESTRING', 'MULTILINESTRING',
                       'POLYGON', 'MULTIPOLYGON')
        return value.lstrip().upper().startswith(valid_types)

    @staticmethod
    def _validate_shapelyobject(value):
        """Return True if `value` is one of the supported shapely geometry types."""
        return isinstance(value, (Polygon, Point, MultiPolygon, LineString, MultiLineString,
                                  MultiPoint, GeometryCollection))

    @staticmethod
    def _geohash_to_coords(value):
        """
        Convert a geohash to the corners of its bounding box.

        Returns:
            ([x_coords], [y_coords]): a single "part" whose four vertices run
            top-left, bottom-left, bottom-right, top-right.
        """
        y, x, y_margin, x_margin = decode_exactly(value)
        x_coords = [
            x - x_margin,
            x - x_margin,
            x + x_margin,
            x + x_margin
        ]
        y_coords = [
            y + y_margin,
            y - y_margin,
            y - y_margin,
            y + y_margin
        ]
        return [x_coords], [y_coords]

    @staticmethod
    def _shape_to_coords(value, wkt=False):
        """
        Convert a shape (a shapely object, or well-known text when `wkt` is True) to exterior
        ring coordinates.

        Returns:
            (x_coords, y_coords): parallel lists with one list of vertex coordinates per part;
            multipart geometries yield one entry per member geometry.
        """
        if wkt:
            value = loads(value)
        # Multipart geometries expose their members through `.geoms` in both Shapely 1.x and 2.x
        # (Shapely 2 removed len()/iteration on the geometry itself).
        parts = list(value.geoms) if hasattr(value, 'geoms') else [value]
        x_coords = list()
        y_coords = list()
        for part in parts:
            # Geometries without an exterior ring (points, lines) are buffered into a polygon.
            if not hasattr(part, 'exterior'):
                part = part.buffer(0)
            xs, ys = part.exterior.coords.xy
            x_coords.append(xs.tolist())
            y_coords.append(ys.tolist())
        return x_coords, y_coords

    @staticmethod
    def _lat_rad(lat):
        """
        Return the Mercator-projected latitude in radians, clamped and halved into [-pi/2, pi/2].
        Used for estimating the Google Maps zoom factor.
        """
        sine = sin(lat * pi / 180.)
        # At the poles sin() hits +/-1 exactly and the log term diverges; return the clamped
        # limit instead of raising ZeroDivisionError (lat=90) or a math domain error (lat=-90).
        if sine >= 1.0:
            return pi / 2.
        if sine <= -1.0:
            return -pi / 2.
        rad_x2 = log((1 + sine) / (1 - sine)) / 2.
        return max(min(rad_x2, pi), -pi) / 2.

    @staticmethod
    def _zoom(map_px, world_px, fraction):
        """
        Return the largest integer zoom at which `fraction` of the world fits into `map_px`
        pixels, or None when `fraction` is zero (a degenerate extent).
        """
        try:
            return floor(log(map_px / world_px / fraction) / log(2))
        except ZeroDivisionError:
            return None

    def _estimate_zoom_level(self, plot_width, height_max=600, zoom_max=21, world_dim=256):
        """
        Given a desired plot width and the data bounds, estimate the best Google Maps zoom factor.

        Parameters:
            plot_width (int): desired plot width, in pixels
            height_max (int): maximum height allowable for the plot
            zoom_max (int): maximum zoom factor (21 is the maximum Google Maps allows)
            world_dim (int): size in pixels of the world map at zoom level 0 (256 is the Google default)
        Returns:
            zoom (int): the zoom factor to be used in the call to Google Maps
            y_center (float): the central latitude for the plot
            x_center (float): the central longitude for the plot
            plot_height (int): the height to be used for the plot
        Raises:
            AssertionError: if no source has been added yet (bounds are unset).
        """
        if None in (self.xmin, self.xmax, self.ymin, self.ymax):
            raise AssertionError('No coordinate bounds are set; call `self.add_source` first.')
        yrange = abs(self.ymax - self.ymin)
        xrange = abs(self.xmax - self.xmin)
        if (xrange == 0) or (yrange == 0):
            plot_height = height_max
        else:
            # Keep the data's aspect ratio, but never narrower than wide nor taller than the cap.
            plot_height = int((plot_width / xrange) * yrange)
            if plot_height < plot_width:
                plot_height = plot_width
            if plot_height > height_max:
                plot_height = height_max
        y_center = (self.ymin + self.ymax) / 2.0
        x_center = (self.xmin + self.xmax) / 2.0
        lat_fraction = (self._lat_rad(self.ymax) - self._lat_rad(self.ymin)) / pi
        lng_diff = self.xmax - self.xmin
        lng_fraction = ((lng_diff + 360) if (lng_diff < 0) else lng_diff) / 360.
        lat_zoom = self._zoom(plot_height, world_dim, lat_fraction)
        lng_zoom = self._zoom(plot_width, world_dim, lng_fraction)
        # A None zoom means that axis has zero extent, so it does not constrain the zoom.
        candidates = [z for z in (lat_zoom, lng_zoom) if z is not None]
        return int(min(candidates + [zoom_max])), y_center, x_center, plot_height

    def _process_input_value(self, value):
        """
        Route an arbitrary input value to the right parser and return standardized coordinates.

        Accepts a geohash string, a well-known-text string, a (longitude, latitude) pair of
        numbers, or a shapely geometry.

        Returns:
            (x_parts, y_parts): parallel lists with one list of coordinates per part.
        Raises:
            ValueError: if the value cannot be parsed.
        """
        if isinstance(value, str):
            if self._validate_geohash(value):
                return self._geohash_to_coords(value)
            if self._validate_wellknowntext(value):
                return self._shape_to_coords(value, wkt=True)
            raise ValueError('Unparseable string input.')
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError('List must contain exactly two elements (longitude, latitude).')
            # bool is a subclass of int but is never a meaningful coordinate.
            if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in value):
                return tuple([[v]] for v in value)
            raise ValueError('Unparseable list input.')
        if self._validate_shapelyobject(value):
            return self._shape_to_coords(value, wkt=False)
        raise ValueError('Unrecognizeable input.')

    def _set_coordinate_bounds(self, df):
        """
        Widen the stored plot bounds to include every coordinate in `df`.

        `df` must have 'x_coords' and 'y_coords' columns whose cells are lists of numbers.
        An empty frame leaves the bounds unchanged.
        """
        if df.empty:
            return
        xmin = df['x_coords'].apply(min).min()
        xmax = df['x_coords'].apply(max).max()
        ymin = df['y_coords'].apply(min).min()
        ymax = df['y_coords'].apply(max).max()
        self.xmin = xmin if self.xmin is None else min(self.xmin, xmin)
        self.xmax = xmax if self.xmax is None else max(self.xmax, xmax)
        self.ymin = ymin if self.ymin is None else min(self.ymin, ymin)
        self.ymax = ymax if self.ymax is None else max(self.ymax, ymax)

    def _build_frame(self, data, column_name=None, metadata=None):
        """
        Normalize raw input into a DataFrame with one row per polygon part.

        Parameters:
            data, column_name: as in `add_source`
            metadata (dict): column name -> list of the same length as `data`
        Returns:
            DataFrame whose columns are the metadata columns followed by 'x_coords' and
            'y_coords'; each coordinate cell holds a list of vertex coordinates.
        Raises:
            ValueError: if `data` is a DataFrame without `column_name`, or a value is unparseable.
            AssertionError: if an input yields different numbers of x and y parts.
        """
        if isinstance(data, DataFrame):
            if column_name is None:
                raise ValueError('If data is a dataframe then column_name must be specified.')
            df = data.copy()
        else:
            df = Series(data).to_frame('raw_data')
            column_name = 'raw_data'
        for key, values in (metadata or {}).items():
            df[key] = values
        keys = [c for c in df.columns if c != column_name]
        meta_columns = {k: df[k].tolist() for k in keys}
        records = []
        for pos, raw_value in enumerate(df[column_name].tolist()):
            x_parts, y_parts = self._process_input_value(raw_value)
            if len(x_parts) != len(y_parts):
                raise AssertionError('X and Y coordinate lists are not equal in length.')
            meta = {k: values[pos] for k, values in meta_columns.items()}
            # Multipart inputs (e.g. multipolygons) become one row per part, each carrying the
            # same metadata; rows are appended, so identical metadata never collides.
            for x_part, y_part in zip(x_parts, y_parts):
                record = dict(meta)
                record['x_coords'] = x_part
                record['y_coords'] = y_part
                records.append(record)
        return DataFrame(records, columns=keys + ['x_coords', 'y_coords'])

    def add_source(self, data, label, column_name=None, **kwargs):
        """
        Add a source to `self.sources`, to be used in the plot, and widen the plot bounds.

        Parameters:
            data (list or DataFrame): a list of geohashes, shapely objects, well-known text strings,
                or longitude/latitude pairs; or a DataFrame where the objects to be plotted are in
                the column `column_name` (all other columns are kept as metadata)
            label (str): a name that can be used to reference the source going forward
            column_name (str): required when `data` is a DataFrame; the column that contains
                geohashes, shapely objects, well-known text strings, or longitude/latitude pairs
            kwargs: lists of the same length as `data`, appended to the data as metadata columns.
        """
        df = self._build_frame(data, column_name, kwargs)
        self._set_coordinate_bounds(df)
        self.sources[label] = ColumnDataSource(df)

    def prepare_plot(self, plot_width=700, plot_height=None, zoom=None, map_type='hybrid', **kwargs):
        """
        Create the actual plot object (stored in `self.plot`), centered on the data added so far.

        Parameters:
            plot_width (int): desired plot width, in pixels
            plot_height (int): desired plot height; calculated automatically if not supplied
            zoom (int): zoom factor for Google Maps; calculated automatically if not supplied
            map_type (str): 'satellite', 'roadmap', or 'hybrid'
            kwargs: any options passed to Bokeh GMapPlot (title, etc.)
        """
        zoom_level, lat_center, lng_center, auto_plot_height = self._estimate_zoom_level(plot_width)
        if plot_height is None:
            plot_height = auto_plot_height
        if zoom is None:
            zoom = zoom_level
        map_options = GMapOptions(lat=lat_center, lng=lng_center, map_type=map_type, zoom=zoom)
        self.plot = GMapPlot(
            x_range=Range1d(),
            y_range=Range1d(),
            map_options=map_options,
            plot_width=plot_width,
            plot_height=plot_height,
            **kwargs
        )
        self.plot.api_key = self.api_key
        self.plot.add_tools(WheelZoomTool(), ResetTool(), PanTool())

    def add_layer(self, source_label, bokeh_model, tooltips=None, **kwargs):
        """
        Add a Bokeh glyph, drawn from a previously added source, to `self.plot`.
        `self.prepare_plot` must have been called before this.

        Parameters:
            source_label (str): a label previously passed to `self.add_source`
            bokeh_model: any Bokeh model or glyph class
            tooltips: string or list of tuples (passed to Bokeh HoverTool)
            kwargs: options passed to `bokeh_model`
        This method allows two special kwargs, 'color' and 'alpha': each sets both the fill_* and
        line_* variant of the property to the same value.
        Raises:
            AssertionError: if `self.prepare_plot` has not been called.
        """
        if self.plot is None:
            raise AssertionError('self.plot is null; call `self.prepare_plot`.')
        if 'color' in kwargs:
            color = kwargs.pop('color')
            kwargs['fill_color'] = color
            kwargs['line_color'] = color
        if 'alpha' in kwargs:
            alpha = kwargs.pop('alpha')
            kwargs['fill_alpha'] = alpha
            kwargs['line_alpha'] = alpha
        source = self.sources[source_label]
        try:
            # Multi-vertex glyphs take the per-part coordinate lists directly.
            model_object = bokeh_model(xs='x_coords', ys='y_coords', name=source_label, **kwargs)
        except AttributeError:
            # Glyphs without xs/ys properties are placed at the mean vertex of each part;
            # those centroid columns are computed once per source.
            model_object = bokeh_model(x='x_coords_point', y='y_coords_point', name=source_label, **kwargs)
            if not all(c in source.column_names for c in ['x_coords_point', 'y_coords_point']):
                source.data['x_coords_point'] = [mean(x) for x in source.data['x_coords']]
                source.data['y_coords_point'] = [mean(y) for y in source.data['y_coords']]
                source.column_names.extend(['x_coords_point', 'y_coords_point'])
        rend = self.plot.add_glyph(source, model_object)
        if tooltips is not None:
            self.plot.add_tools(HoverTool(tooltips=tooltips, renderers=[rend]))

    def render_plot(self, display_type='object'):
        """
        Pull everything together into a plot ready for display.

        Parameters:
            display_type (str): 'object', 'auto', 'notebook', or 'file'.
                'object' returns the plot object; 'notebook' displays it in the notebook;
                'file' saves it to 'plot.html' in the current working directory and returns
                that filename; 'auto' raises ValueError above 1,000,000 points, saves to file
                above 100,000 points, and otherwise shows it in the notebook, which can
                prevent large plots from freezing up a notebook.
        Raises:
            ValueError: for an unknown `display_type`, or too many points in 'auto' mode.
        """
        n_points = 0
        for source in self.sources.values():
            try:
                n_points += sum(len(x) for x in source.data['x_coords'])
            except TypeError:
                n_points += len(source.data)
        if display_type == 'auto':
            # Check the hard limit before the file threshold, otherwise it is unreachable.
            if n_points > 1000000:
                raise ValueError("Too many points to plot (you'll crash your browser).")
            display_type = 'file' if n_points > 100000 else 'notebook'
        if display_type == 'notebook':
            output_notebook(CDN, hide_banner=True, load_timeout=60000)
            show(self.plot)
        elif display_type == 'file':
            save(self.plot, filename='plot.html', resources=CDN, title='plot')
            return 'plot.html'
        elif display_type == 'object':
            return self.plot
        else:
            raise ValueError('Invalid display type specified.')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment