# Source: GitHub gist by @Sangarshanan
# https://gist.github.com/Sangarshanan/a258462621a813f898e4013ac5f22979
# (last active February 14, 2021)
"""
Folium choloropeth with suport for Branca colormaps and custom feature functions
"""
import folium
import warnings
import numpy as np
from folium.map import FeatureGroup
from branca.utilities import color_brewer
from branca.colormap import ColorMap, StepColormap, LinearColormap
class Choropleth(FeatureGroup):
    """Choropleth layer accepting ColorBrewer codes, branca colormaps or color functions.

    ``fill_color`` selects how each feature is colored:

    * a branca ``ColorMap`` (e.g. ``branca.colormap.linear.viridis.scale(...)``):
      the colormap is applied to the value ``data`` holds for the feature's
      key; features with no value, or a NaN value, get ``nan_fill_color``.
    * any other callable: called with ``feature["properties"][columns[1]]``
      (the value is read from the GeoJSON properties, not from ``data``) and
      must return a color.
    * a ColorBrewer code or ``None``: values from ``data`` are binned with
      ``np.histogram(..., bins=bins)`` and a StepColormap legend captioned
      ``legend_name`` is added. ``None`` means ``"Blues"`` when ``data`` is
      given and the plain color ``"blue"`` otherwise.

    Parameters
    ----------
    geo_data
        GeoJSON (or TopoJSON when ``topojson`` is set) handed to
        ``folium.GeoJson`` / ``folium.TopoJson``.
    data
        pandas DataFrame (keys in ``columns[0]``, values in ``columns[1]``),
        pandas Series (keys in the index), anything ``dict()`` accepts, or None.
    columns
        ``[key_column, value_column]``.
    key_on
        Dotted path to each feature's join key, e.g. ``"feature.properties.id"``;
        the leading ``"feature."`` is optional. Without it the colormap path
        falls back to ``feature["properties"][columns[0]]``, and the ColorBrewer
        path paints every feature with ``fill_color`` itself.
    bins
        Bin count or bin edges for the ColorBrewer path. ``threshold_scale``
        is still accepted through ``**kwargs`` as a deprecated alias.
    nan_fill_color, nan_fill_opacity
        Style for features without a usable value; ``nan_fill_opacity``
        defaults to ``fill_opacity``.
    highlight
        When truthy, a highlight style is attached (GeoJSON input only).

    Raises
    ------
    ValueError
        On an invalid ColorBrewer code (when ``data`` is given), values that
        fall outside ``bins``, a ``key_on`` path missing from a feature, or a
        branca colormap passed without ``data``.
    """

    def __init__(
        self,
        geo_data,
        data=None,
        columns=None,
        key_on=None,  # noqa
        bins=6,
        fill_color=None,
        nan_fill_color="black",
        fill_opacity=0.6,
        nan_fill_opacity=None,
        line_color="black",
        line_weight=1,
        line_opacity=1,
        name=None,
        legend_name="",
        overlay=True,
        control=True,
        show=True,
        topojson=None,
        smooth_factor=None,
        highlight=None,
        **kwargs
    ):
        super(Choropleth, self).__init__(
            name=name, overlay=overlay, control=control, show=show
        )
        self._name = "Choropleth"

        if "threshold_scale" in kwargs:
            if kwargs["threshold_scale"] is not None:
                bins = kwargs["threshold_scale"]
            warnings.warn(
                "choropleth `threshold_scale` parameter is now deprecated "
                "in favor of the `bins` parameter.",
                DeprecationWarning,
            )

        # Resolved up front so every coloring mode can use it.
        if nan_fill_opacity is None:
            nan_fill_opacity = fill_opacity

        color_data = self._to_color_data(data, columns)
        if key_on is not None and key_on.startswith("feature."):
            key_on = key_on[len("feature."):]

        def feature_key(feature):
            """Return the join key of ``feature`` (via ``key_on`` when given)."""
            if key_on is None:
                return feature["properties"][columns[0]]
            key = self._get_by_key(feature, key_on)
            if key is None:
                raise ValueError(
                    "key_on `{!r}` not found in GeoJSON.".format(key_on)
                )
            return key

        self.color_scale = None
        if isinstance(fill_color, ColorMap):
            # Checked before callable(): branca colormaps are callable too.
            if color_data is None:
                raise ValueError(
                    "A branca colormap `fill_color` requires `data`."
                )
            # Evaluate the colormap once per data row instead of rebuilding a
            # {key: color} dict on every style call. Missing/NaN values are
            # left out so those features fall back to nan_fill_color.
            key_to_color = {
                key: fill_color(value)
                for key, value in color_data.items()
                if not self._is_missing(value)
            }

            def color_scale_fun(feature):
                color = key_to_color.get(feature_key(feature))
                if color is None:
                    return nan_fill_color, nan_fill_opacity
                return color, fill_opacity

        elif callable(fill_color):
            def color_scale_fun(feature):
                return fill_color(feature["properties"][columns[1]]), fill_opacity

        else:
            fill_color = fill_color or ("blue" if data is None else "Blues")
            if data is not None and not color_brewer(fill_color):
                raise ValueError(
                    "Please pass a valid color brewer code to "
                    "fill_color. See docstring for valid codes."
                )

            if color_data is not None and key_on is not None:
                real_values = np.array(list(color_data.values()))
                real_values = real_values[~np.isnan(real_values)]
                _, bin_edges = np.histogram(real_values, bins=bins)

                bins_min, bins_max = min(bin_edges), max(bin_edges)
                if np.any((real_values < bins_min) | (real_values > bins_max)):
                    raise ValueError(
                        "All values are expected to fall into one of the provided "
                        "bins (or to be Nan). Please check the `bins` parameter "
                        "and/or your data."
                    )

                # Legend: one ColorBrewer color per bin.
                nb_bins = len(bin_edges) - 1
                color_range = color_brewer(fill_color, n=nb_bins)
                self.color_scale = StepColormap(
                    color_range,
                    index=bin_edges,
                    vmin=bins_min,
                    vmax=bins_max,
                    caption=legend_name,
                )

                # np.digitize uses half-open [a, b) bins, so nudge the last
                # edge up by one ulp to make the top bin include its right
                # edge. Work on a float copy: np.histogram may hand back the
                # caller's own `bins` array (which must not be mutated) or an
                # integer array, in which the nudge would be truncated away.
                digitize_edges = np.array(bin_edges, dtype=float)
                increasing = digitize_edges[0] <= digitize_edges[-1]
                digitize_edges[-1] = np.nextafter(
                    digitize_edges[-1], (1 if increasing else -1) * np.inf
                )

                def color_scale_fun(feature):
                    key = feature_key(feature)
                    if key not in color_data:
                        return nan_fill_color, nan_fill_opacity
                    value = color_data[key]
                    if np.isnan(value):
                        return nan_fill_color, nan_fill_opacity
                    color_idx = np.digitize(value, digitize_edges, right=False) - 1
                    return color_range[color_idx], fill_opacity

            else:
                # No data to join: every feature gets fill_color as-is.
                def color_scale_fun(feature):
                    return fill_color, fill_opacity

        def style_function(feature):
            color, opacity = color_scale_fun(feature)
            return {
                "weight": line_weight,
                "opacity": line_opacity,
                "color": line_color,
                "fillOpacity": opacity,
                "fillColor": color,
            }

        def highlight_function(feature):
            # Thicker outline and a more opaque fill; opacity is capped at 1.0
            # because fill_opacity + 0.2 can exceed it.
            return {
                "weight": line_weight + 2,
                "fillOpacity": min(fill_opacity + 0.2, 1.0),
            }

        if topojson:
            self.geojson = folium.TopoJson(
                geo_data,
                topojson,
                style_function=style_function,
                smooth_factor=smooth_factor,
            )
        else:
            self.geojson = folium.GeoJson(
                geo_data,
                style_function=style_function,
                smooth_factor=smooth_factor,
                highlight_function=highlight_function if highlight else None,
            )

        self.add_child(self.geojson)
        if self.color_scale:
            self.add_child(self.color_scale)

    @staticmethod
    def _to_color_data(data, columns):
        """Normalize ``data`` into a ``{key: value}`` dict, or None when absent."""
        if hasattr(data, "set_index"):
            # pd.DataFrame: keys in columns[0], values in columns[1].
            return data.set_index(columns[0])[columns[1]].to_dict()
        if hasattr(data, "to_dict"):
            # pd.Series: keys come from the index.
            return data.to_dict()
        if data:
            return dict(data)
        return None

    @staticmethod
    def _get_by_key(obj, key):
        """Follow the dotted path ``key`` through nested mappings in ``obj``.

        Returns None as soon as a level is missing or is not a mapping, so a
        partially-matching path is reported as "not found" rather than raising
        AttributeError.
        """
        for part in key.split("."):
            if not hasattr(obj, "get"):
                return None
            obj = obj.get(part, None)
        return obj

    @staticmethod
    def _is_missing(value):
        """Return True for None or NaN; non-numeric values count as present."""
        if value is None:
            return True
        try:
            return bool(np.isnan(value))
        except TypeError:
            return False

    def render(self, **kwargs):
        """Render the GeoJson/TopoJson layer and, if present, the color scale legend."""
        if self.color_scale:
            # ColorMap needs the Map itself as its parent, so re-parent the
            # legend before rendering.
            assert isinstance(self._parent, folium.Map), (
                "Choropleth must be added" " to a Map object."
            )
            self.color_scale._parent = self._parent
        super(Choropleth, self).render(**kwargs)
import geopandas
import folium
# World country polygons from the sample datasets bundled with geopandas.
# NOTE(review): newer geopandas releases may no longer ship
# `geopandas.datasets`; confirm against the installed version.
df = geopandas.read_file(
    geopandas.datasets.get_path("naturalearth_lowres")
)
# One integer id per row; it is the join key between the GeoJSON features
# (via key_on) and the DataFrame rows (via columns[0]).
df["id"] = range(df.shape[0])

# Example: ColorBrewer palette, values binned into a step legend.
m = folium.Map(zoom_start=1, location=[8.7832, 34.5085])
c = Choropleth(
    df.__geo_interface__,
    data=df,
    columns=["id", "pop_est"],
    key_on="feature.properties.id",
    fill_color="YlGnBu",
    name="Map",
)
m.add_child(c)
# Custom Branca Plots
import branca
from branca.colormap import ColorMap,StepColormap,LinearColormap
# A continuous viridis colormap rescaled to the observed population range;
# the Choropleth applies it to each country's pop_est value.
step = branca.colormap.linear.viridis.scale(
    vmin=df["pop_est"].min(),
    vmax=df["pop_est"].max(),
)
m = folium.Map(zoom_start=1, location=[8.7832, 34.5085])
c = Choropleth(
    df.__geo_interface__,
    data=df,
    columns=["id", "pop_est"],
    key_on="feature.properties.id",
    fill_color=step,
    name="Map",
)
m.add_child(c)
# Custom Function Plots
def my_color_function(field):
    """Return red ("#ff0000") for values above 100 million, else green ("#008000").

    Anything that does not compare greater than the threshold -- including
    NaN -- falls through to green.
    """
    is_high = field > 100000000
    return "#ff0000" if is_high else "#008000"
# Example: a plain Python function picks each country's color from its
# pop_est property.
m = folium.Map(zoom_start=1, location=[8.7832, 34.5085])
c = Choropleth(
    df.__geo_interface__,
    data=df,
    columns=["id", "pop_est"],
    key_on="feature.properties.id",
    fill_color=my_color_function,
    name="Map",
)
m.add_child(c)
# A bare trailing expression only displays the map in an interactive session
# (e.g. Jupyter); a plain script would need m.save(...) instead.
m
# Discussion of this snippet lives on the GitHub gist page linked above.