Last active
May 8, 2018 17:13
-
-
Save schaunwheeler/9a98d8bee5f039e9872c76fb24a6e69c to your computer and use it in GitHub Desktop.
Example of workflow automation using Bokeh. Copyright (c) 2018 Valassis Digital under the terms of the BSD 3-Clause license.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Code described in Medium post https://medium.com/@schaun.wheeler/codify-your-workflow-377f5f8bf4c3 | |
Copyright (c) 2018 Valassis Digital under the terms of the BSD 3-Clause license. | |
""" | |
from math import sin, pi, log, floor | |
from shapely.wkt import loads, dumps | |
from shapely.geometry import Polygon, Point, MultiPolygon, LineString, MultiLineString, MultiPoint, GeometryCollection | |
from bokeh.io import output_notebook, show, output_file, reset_output, save | |
from bokeh.models import GMapPlot, GMapOptions, ColumnDataSource, Range1d | |
from bokeh.models.tools import HoverTool, WheelZoomTool, ResetTool, PanTool | |
from bokeh.models.renderers import GlyphRenderer | |
from bokeh.resources import CDN | |
from os import getcwd | |
from pandas import Series, DataFrame | |
from numpy import mean | |
from geohash2 import decode_exactly, decode, encode | |
class geoPlotter(object):
    """
    This class provides a wrapper to produce most of the boilerplate needed to use Bokeh to plot on
    top of Google Maps images.

    Example usage (``Patches``, ``Circle`` and ``Text`` are Bokeh glyph classes):

        from bokeh.models import Patches, Circle, Text

        geohashes = [
            'dnrgrfm', 'dnrgrf3', 'dnrgrf7', 'dnrgrf5', 'dnrgrfk', 'dnrgrf1',
            'dnrgrfh', 'dnrgrf2', 'dnrgrcu', 'dnrgrfj', 'dnrgrcv', 'dnrgrf4'
        ]
        # get a key at developers.google.com/maps/documentation/javascript/get-api-key
        API_KEY = '...'
        gp = geoPlotter(api_key=API_KEY)
        gp.add_source(geohashes, label='a', order=range(len(geohashes)))
        gp.prepare_plot(plot_width=700)
        gp.add_layer('a', Patches, color='yellow', alpha=0.5)
        gp.add_layer('a', Circle, color='blue', size=20, alpha=0.75)
        gp.add_layer('a', Text, text='order', text_color='white', text_font_size='10pt',
                     text_align='center', text_baseline='middle')
        gp.render_plot('auto')
    """

    def __init__(self, api_key):
        """
        To initialize the class, you must include a Google Maps API key, which can be obtained at
        developers.google.com/maps/documentation/javascript/get-api-key
        """
        self.api_key = api_key
        # label -> ColumnDataSource, filled by add_source
        self.sources = dict()
        # Bounding box (x = longitude, y = latitude) over every source added so far;
        # None until the first add_source call.
        self.xmin = None
        self.ymin = None
        self.xmax = None
        self.ymax = None
        # GMapPlot created by prepare_plot
        self.plot = None

    @staticmethod
    def _validate_geohash(value):
        """Check whether an input value is a valid geohash: 1-12 characters of the geohash base32 alphabet."""
        if not 0 < len(value) <= 12:
            return False
        base32 = '0123456789bcdefghjkmnpqrstuvwxyz'
        return all(v in base32 for v in value)

    @staticmethod
    def _validate_wellknowntext(value):
        """Check whether an input value starts with a well-known-text geometry tag."""
        # 'GEOMETRY' is also a prefix of 'GEOMETRYCOLLECTION', so collections are accepted too.
        valid_types = ('GEOMETRY', 'POINT', 'MULTIPOINT', 'LINESTRING', 'MULTILINESTRING', 'POLYGON', 'MULTIPOLYGON')
        return value.startswith(valid_types)

    @staticmethod
    def _validate_shapelyobject(value):
        """Check whether an input value is a valid shapely object."""
        return isinstance(value, (Polygon, Point, MultiPolygon, LineString, MultiLineString, MultiPoint, GeometryCollection))

    @staticmethod
    def _geohash_to_coords(value):
        """
        Convert a geohash to the x and y values of that geohash's bounding box.

        Returns ``([x_coords], [y_coords])`` -- a single part with four corners, ordered
        north-west, south-west, south-east, north-east.
        """
        # centre of the cell plus its half-height / half-width
        y, x, y_margin, x_margin = decode_exactly(value)
        x_coords = [x - x_margin, x - x_margin, x + x_margin, x + x_margin]
        y_coords = [y + y_margin, y - y_margin, y - y_margin, y + y_margin]
        return [x_coords], [y_coords]

    @staticmethod
    def _shape_to_coords(value, wkt=False):
        """
        Convert a shape (a shapely object or well-known text) to x and y coordinates.

        Parameters:
            value: a shapely geometry, or a WKT string when `wkt` is True
            wkt (bool): parse `value` with shapely's WKT reader first
        Returns:
            (x_coords, y_coords): parallel lists with one list of coordinates per part.
            Multi-part geometries contribute one entry per member. Polygons contribute their
            exterior ring, points and lines their own vertices; anything else is passed through
            ``buffer(0)`` and its exterior ring is used.
        """
        if wkt:
            value = loads(value)
        # Iterate `.geoms`: Shapely 2.0 removed len()/iteration on multi-part geometries.
        if isinstance(value, (MultiPolygon, MultiLineString, MultiPoint, GeometryCollection)):
            parts = list(value.geoms)
        else:
            parts = [value]
        x_coords = list()
        y_coords = list()
        for part in parts:
            if hasattr(part, 'exterior'):
                xs, ys = part.exterior.coords.xy
            elif isinstance(part, (Point, LineString)):
                # buffer(0) of a zero-area geometry is empty, so use the vertices directly
                xs, ys = part.coords.xy
            else:
                xs, ys = part.buffer(0).exterior.coords.xy
            x_coords.append(xs.tolist())
            y_coords.append(ys.tolist())
        return x_coords, y_coords

    @staticmethod
    def _lat_rad(lat):
        """
        Return half of the Mercator y-value for a latitude given in degrees
        (for estimating zoom factor for Google Maps).

        The projected value is clamped to [-pi, pi] before halving, so the result lies in
        [-pi/2, pi/2].
        """
        sine = sin(lat * pi / 180.)
        # At +/-90 degrees the projection is infinite (the formula below would divide by zero
        # or take log(0)); the clamp maps that to +/-pi, i.e. +/-pi/2 after halving.
        if sine >= 1.:
            return pi / 2.
        if sine <= -1.:
            return -pi / 2.
        rad_x2 = log((1 + sine) / (1 - sine)) / 2.
        return max(min(rad_x2, pi), -pi) / 2.

    @staticmethod
    def _zoom(map_px, world_px, fraction):
        """
        Calculate the zoom factor for Google Maps.

        Returns the largest integer zoom z for which `fraction` of a world map that is
        `world_px` pixels wide at zoom 0 still fits in `map_px` pixels, or None when no finite
        zoom can be computed.
        """
        try:
            return floor(log(map_px / world_px / fraction) / log(2))
        except (ZeroDivisionError, ValueError):
            # ZeroDivisionError: zero-sized extent; ValueError: log of a non-positive ratio
            return None

    def _estimate_zoom_level(self, plot_width, height_max=600, zoom_max=21, world_dim=256, plot_height=None):
        """
        Given a desired plot width and data source(s), estimate the best zoom factor for Google Maps.

        Parameters:
            plot_width (int): desired plot width, in pixels
            height_max (int): maximum height allowable for an automatically sized plot
            zoom_max (int): maximum zoom factor (21 is the maximum Google Maps allows)
            world_dim (int): size in pixels of the world map at zoom 0 (256 is the Google default)
            plot_height (int): optional fixed plot height; when None it is derived from the data
        Returns:
            Zoom factor (int): the zoom factor to be used in the call to Google Maps
            y_center (float): the central latitude for the plot
            x_center (float): the central longitude for the plot
            plot_height (int): the height to be used for the plot
        Raises:
            AssertionError: if no source has been added yet (bounds are unset)
        """
        if any(b is None for b in (self.xmin, self.xmax, self.ymin, self.ymax)):
            raise AssertionError('No coordinate bounds are set; call `self.add_source` first.')
        y_span = abs(self.ymax - self.ymin)
        x_span = abs(self.xmax - self.xmin)
        if plot_height is None:
            if (x_span == 0) or (y_span == 0):
                plot_height = height_max
            else:
                # scale by the data's (degree-based) aspect ratio
                plot_height = int((plot_width / x_span) * y_span)
            # never shorter than wide, never taller than height_max
            plot_height = min(max(plot_height, plot_width), height_max)
        y_center = (self.ymin + self.ymax) / 2.0
        x_center = (self.xmin + self.xmax) / 2.0
        lat_fraction = (self._lat_rad(self.ymax) - self._lat_rad(self.ymin)) / pi
        lng_diff = self.xmax - self.xmin
        # wrap a negative longitude span (kept for safety; xmax >= xmin whenever bounds come
        # from _set_coordinate_bounds)
        lng_fraction = ((lng_diff + 360) if (lng_diff < 0) else lng_diff) / 360.
        lat_zoom = self._zoom(plot_height, world_dim, lat_fraction)
        lng_zoom = self._zoom(plot_width, world_dim, lng_fraction)
        candidates = [z for z in (lat_zoom, lng_zoom) if z is not None]
        return int(min(candidates + [zoom_max])), y_center, x_center, plot_height

    def _process_input_value(self, value):
        """
        Router function for values: take an arbitrary value, validate, and return standardized coordinate output.

        Accepts a geohash string, a well-known-text string, a (longitude, latitude) pair of
        numbers, or a shapely geometry. Returns ``(x_coords, y_coords)``: parallel lists holding
        one list of coordinates per part of the shape.

        Raises:
            ValueError: if the value is none of the accepted forms
        """
        if isinstance(value, str):
            if self._validate_geohash(value):
                return self._geohash_to_coords(value)
            if self._validate_wellknowntext(value):
                return self._shape_to_coords(value, wkt=True)
            raise ValueError('Unparseable string input.')
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError('List must contain exactly two elements (longitude, latitude).')
            # bool is an int subclass but never a coordinate
            if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in value):
                return tuple([[v]] for v in value)
            raise ValueError('Unparseable list input.')
        if self._validate_shapelyobject(value):
            return self._shape_to_coords(value, wkt=False)
        raise ValueError('Unrecognizeable input.')

    def _set_coordinate_bounds(self, df):
        """
        Given a new source, set or update coordinate bounds for the plot.

        `df` needs `x_coords` and `y_coords` columns whose cells are lists of numbers. Each
        stored bound only ever widens to include the new source; a source without any
        coordinates leaves the bounds unchanged.
        """
        xs = [v for cell in df['x_coords'] for v in cell]
        ys = [v for cell in df['y_coords'] for v in cell]
        if not xs or not ys:
            return
        xmin, xmax = min(xs), max(xs)
        ymin, ymax = min(ys), max(ys)
        self.xmin = xmin if self.xmin is None else min(self.xmin, xmin)
        self.xmax = xmax if self.xmax is None else max(self.xmax, xmax)
        self.ymin = ymin if self.ymin is None else min(self.ymin, ymin)
        self.ymax = ymax if self.ymax is None else max(self.ymax, ymax)

    def _build_source_frame(self, data, column_name, metadata):
        """
        Turn `add_source` inputs into a DataFrame with one row per shape part.

        Columns are the metadata columns (for DataFrame input, every column except
        `column_name`; then any `metadata` entries) followed by `x_coords` and `y_coords`, whose
        cells are lists of coordinates. Multi-part inputs yield several rows repeating the same
        metadata.

        Raises:
            ValueError: if `data` is a DataFrame and `column_name` is None, or a value is unparseable
            AssertionError: if a value produces differing numbers of x and y parts
        """
        if isinstance(data, DataFrame):
            if column_name is None:
                raise ValueError('If data is a dataframe then column_name must be specified.')
            df = data.copy()
            keys = [c for c in df.columns if c != column_name]
        else:
            df = Series(data).to_frame('raw_data')
            keys = []
            column_name = 'raw_data'
        for k, v in metadata.items():
            df[k] = v
            if k not in keys:
                keys.append(k)
        processed = [self._process_input_value(value) for value in df[column_name]]
        # itertuples over zero columns yields nothing, so pad with empty metadata tuples instead
        meta_rows = list(df[keys].itertuples(index=False, name=None)) if keys else [()] * len(df)
        records = []
        for meta, (x_parts, y_parts) in zip(meta_rows, processed):
            if len(x_parts) != len(y_parts):
                raise AssertionError('X and Y coordinate lists are not equal in length.')
            for x_part, y_part in zip(x_parts, y_parts):
                records.append(list(meta) + [x_part, y_part])
        return DataFrame(records, columns=keys + ['x_coords', 'y_coords'])

    def add_source(self, data, label, column_name=None, **kwargs):
        """
        Add a source to the `self.sources`, to be used in the plot.

        Parameters:
            data (list or DataFrame): a list of geohashes, shapely objects, well-known text strings, or longitude/latitude pairs;
                or a DataFrame where the objects to be plotted are indicated by `column_name`
            label (str): a name that can be used to reference the source going forward
            column_name (str): optional, the column of `data` that contains geohashes, shapely objects, well-known text strings,
                or longitude/latitude pairs
            kwargs: lists of the same length as `data`. These will be appended to the data as metadata
                (for both list and DataFrame input).

        Multi-part shapes (e.g. multipolygons) are split into one row per part, each carrying the
        metadata of the value it came from. The plot bounds are widened to include the new source.
        """
        df = self._build_source_frame(data, column_name, kwargs)
        self._set_coordinate_bounds(df)
        self.sources[label] = ColumnDataSource(df)

    def prepare_plot(self, plot_width=700, plot_height=None, zoom=None, map_type='hybrid', **kwargs):
        """
        Create the actual plot object (stored in `self.plot`).

        Parameters:
            plot_width (int): desired plot width, in pixels
            plot_height (int): desired plot height, will be calculated automatically if not supplied
            zoom (int): zoom factor for Google Maps, will be calculated automatically (for the
                final plot height) if not supplied
            map_type (string): 'satellite', 'roadmap', or 'hybrid'
            kwargs: any options passed to Bokeh GMapPlot (title, etc.)
        """
        zoom_level, lat_center, lng_center, plot_height = self._estimate_zoom_level(
            plot_width, plot_height=plot_height)
        if zoom is None:
            zoom = zoom_level
        map_options = GMapOptions(lat=lat_center, lng=lng_center, map_type=map_type, zoom=zoom)
        # NOTE(review): plot_width/plot_height are the pre-3.0 Bokeh argument names; newer
        # releases use width/height -- confirm against the installed Bokeh version.
        self.plot = GMapPlot(
            x_range=Range1d(),
            y_range=Range1d(),
            map_options=map_options,
            plot_width=plot_width,
            plot_height=plot_height,
            **kwargs
        )
        self.plot.api_key = self.api_key
        self.plot.add_tools(WheelZoomTool(), ResetTool(), PanTool())

    def add_layer(self, source_label, bokeh_model, tooltips=None, **kwargs):
        """
        Add bokeh models (glyphs or markers) to `self.plot`.
        `self.prepare_plot` must have been called previous to this.

        Parameters:
            source_label (str): string corresponding to a label previously used in `self.add_source`
            bokeh_model: any Bokeh model or glyph class
            tooltips: string or list of tuples (passed to Bokeh HoverTool)
            kwargs: options passed to the `bokeh_model` constructor

        This method allows two special kwargs: 'color' and 'alpha'. When used with a bokeh model that has 'fill_color'
        and 'line_color' and 'fill_alpha' and 'line_alpha' properties, calling the special kwarg will use the same value
        for both.

        Models that accept `xs`/`ys` are bound to the shape coordinates; other models are bound to
        `x_coords_point`/`y_coords_point`, the mean of each shape's vertices, which are added
        to the source on first use.
        """
        if self.plot is None:
            raise AssertionError('self.plot is null; call `self.prepare_plot`.')
        if 'color' in kwargs:
            color = kwargs.pop('color')
            kwargs['fill_color'] = color
            kwargs['line_color'] = color
        if 'alpha' in kwargs:
            alpha = kwargs.pop('alpha')
            kwargs['fill_alpha'] = alpha
            kwargs['line_alpha'] = alpha
        source = self.sources[source_label]
        try:
            model_object = bokeh_model(xs='x_coords', ys='y_coords', name=source_label, **kwargs)
        except AttributeError:
            # Bokeh rejects unknown properties with AttributeError: this model has no xs/ys,
            # so place it at one point per shape.
            model_object = bokeh_model(x='x_coords_point', y='y_coords_point', name=source_label, **kwargs)
            point_columns = ['x_coords_point', 'y_coords_point']
            if not all(c in source.column_names for c in point_columns):
                # NOTE(review): closed rings (e.g. shapely exteriors) repeat their first vertex,
                # which this plain mean counts twice.
                source.data['x_coords_point'] = [mean(x) for x in source.data['x_coords']]
                source.data['y_coords_point'] = [mean(y) for y in source.data['y_coords']]
                column_names = source.column_names
                column_names.extend(c for c in point_columns if c not in column_names)
        renderer = self.plot.add_glyph(source, model_object)
        if tooltips is not None:
            self.plot.add_tools(HoverTool(tooltips=tooltips, renderers=[renderer]))

    def render_plot(self, display_type='object'):
        """
        Pull everything together into a plot ready for display.

        Parameters:
            display_type (str): either 'object', 'auto', 'notebook', or 'file'. If 'object', it returns
                the plot object. If 'notebook', the plot is displayed in the notebook. If 'file', the plot
                is saved to 'plot.html' in the current working directory (and 'plot.html' is returned).
                If 'auto', it counts the plotted coordinates across all sources: more than 1,000,000
                raises ValueError, more than 100,000 behaves like 'file', otherwise like 'notebook'.
                This can prevent the plot from freezing up a notebook.
        Raises:
            ValueError: for an unknown `display_type`, or too many points in 'auto' mode
            AssertionError: if `self.prepare_plot` has not been called (for display types other than 'object')
        """
        if display_type == 'object':
            return self.plot
        if display_type not in ('auto', 'notebook', 'file'):
            raise ValueError('Invalid display type specified.')
        if self.plot is None:
            raise AssertionError('self.plot is null; call `self.prepare_plot`.')
        if display_type == 'auto':
            n_points = 0
            for source in self.sources.values():
                x_coords = source.data['x_coords']
                try:
                    n_points += sum(len(x) for x in x_coords)
                except TypeError:
                    # unsized entries: fall back to counting rows
                    n_points += len(x_coords)
            # check the hard limit first, otherwise it could never be reached
            if n_points > 1000000:
                raise ValueError("Too many points to plot (you'll crash your browser).")
            display_type = 'file' if n_points > 100000 else 'notebook'
        if display_type == 'file':
            save(self.plot, filename='plot.html', resources=CDN, title='plot')
            return 'plot.html'
        output_notebook(CDN, hide_banner=True, load_timeout=60000)
        show(self.plot)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment