Skip to content

Instantly share code, notes, and snippets.

@benman1
Last active January 16, 2020 09:39
Show Gist options
  • Save benman1/456f365e392dff05f41f8713fe04bc74 to your computer and use it in GitHub Desktop.
Save benman1/456f365e392dff05f41f8713fe04bc74 to your computer and use it in GitHub Desktop.
Plot word cloud illustrating feature importance
import matplotlib.pyplot as plt
import numpy as np
from colour import Color
from wordcloud import (WordCloud, get_single_color_func)
class GroupedColorFunc(object):
    """Color function that assigns DIFFERENT SHADES of a group color to words.

    Each color in ``color_to_words`` is turned into a color function via
    ``wordcloud.get_single_color_func``; a word is drawn with the function of
    the first group that contains it, or with the default color function when
    it belongs to no group.

    Parameters
    ----------
    color_to_words : dict(str -> list(str))
        A dictionary that maps a color to the list of words.
    default_color : str
        Color that will be assigned to a word that's not a member
        of any value from color_to_words.

    Notes
    -----
    Taken from
    https://amueller.github.io/word_cloud/auto_examples/colored_by_group.html
    """

    def __init__(self, color_to_words, default_color):
        # One (color function, set of words) pair per group; sets make the
        # membership test in get_color_func cheap.
        self.color_func_to_words = [
            (get_single_color_func(color), set(words))
            for (color, words) in color_to_words.items()
        ]
        self.default_color_func = get_single_color_func(default_color)

    def get_color_func(self, word):
        """Return the color function associated with ``word``.

        Groups are checked in the order they were given; the first group
        containing the word wins. Words found in no group get the default
        color function.
        """
        for color_func, words in self.color_func_to_words:
            if word in words:
                return color_func
        return self.default_color_func

    def __call__(self, word, **kwargs):
        """Color ``word`` by delegating to its group's color function.

        All keyword arguments are forwarded unchanged to that function.
        """
        return self.get_color_func(word)(word, **kwargs)
def plot_feature_wordcloud(shap_values, feature_names, background:str='black',
                           thresh:float=0.01, zoom_level:float=10.0):
    """Plot a word cloud illustrating feature importance.

    This works with shap values and ``feature_importances_``. The size of a
    word corresponds to the importance of the feature. With shap values the
    color additionally reflects the direction of the influence.

    Parameters
    ----------
    shap_values : array-like
        Either a 2-D array of SHAP values (samples x features) or a 1-D
        vector such as ``model.feature_importances_``.
    feature_names : iterable of str
        Names matching the features in ``shap_values``, in the same order.
    background : str
        Background color of the word cloud ('black').
    thresh : float
        Lower cutoff; features whose scaled importance is below this value
        are dropped from the cloud (0.01).
    zoom_level : float
        Divisor applied before the ``tanh`` squashing. Increase it to see
        more of the lower valued features (10.0).

    Raises
    ------
    ValueError
        If no feature has a scaled importance of at least ``thresh``, since
        there would be nothing to draw.

    Notes
    -----
    Directions are computed as ``-sign(shap_values)`` and mapped onto a
    dark-red (#4d0000) to green (#00ff00) scale, so positive SHAP values end
    up at the dark-red end and negative ones towards green.

    Example
    -------
    With shap values:
    >> plot_feature_wordcloud(shap_values, jobtitle_vectorizer.get_feature_names())
    With feature_importances_:
    >> plot_feature_wordcloud(
        model.feature_importances_, jobtitle_vectorizer.get_feature_names()
    )
    """
    def bin_normalize(vec):
        """Return (1-based bin index per element, number of bin edges)."""
        edges = np.histogram_bin_edges(vec, bins='auto')
        indices = np.digitize(vec, bins=edges)
        # np.digitize places values equal to the last edge one past the final
        # bin; fold them back into the last real bin.
        indices[indices == edges.shape[0]] = edges.shape[0] - 1
        return indices, edges.shape[0]

    # Accept plain lists as well as arrays, and make the names re-iterable
    # (they are consumed twice: once for colors, once for sizes).
    shap_values = np.atleast_1d(np.asarray(shap_values))
    feature_names = list(feature_names)

    # tanh squashes magnitudes into [0, 1); zoom_level controls how quickly
    # large values saturate.
    feature_importances = np.tanh(np.abs(shap_values) / zoom_level)
    directions = -np.sign(shap_values)
    if shap_values.ndim > 1:
        # Collapse the sample axis so there is one value per feature. A 1-D
        # input (e.g. feature_importances_) is already per-feature.
        feature_importances = np.median(feature_importances, axis=0)
        directions = np.mean(directions, axis=0)
    feature_importances = np.where(
        feature_importances < thresh, 0.0, feature_importances
    )

    # sizes from feature importances; zero-weight words are never drawn, so
    # leave them out.
    frequencies = {
        word: float(weight)
        for word, weight in zip(feature_names, feature_importances)
        if weight > 0
    }
    if not frequencies:
        raise ValueError(
            'No feature importance reaches thresh=%r; nothing to plot. '
            'Lower thresh or zoom_level.' % (thresh,)
        )

    directions_bins, n_edges = bin_normalize(directions)
    # red to green color spectrum:
    colors = list(Color('#4d0000').range_to(Color('#00ff00'), n_edges))
    color_to_words = {}
    for bin_index, word in zip(directions_bins, feature_names):
        color_hex = colors[bin_index - 1].get_hex()
        color_to_words.setdefault(color_hex, []).append(word)

    wc = WordCloud(
        width=1600,
        height=800,
        collocations=False,
        background_color=background,
        max_font_size=50
    )
    wc = wc.generate_from_frequencies(frequencies)

    # color by direction (positive-negative)
    default_color = 'grey'
    grouped_color_func = GroupedColorFunc(color_to_words, default_color)
    wc.recolor(color_func=grouped_color_func)

    plt.figure(figsize=(20, 10))
    plt.imshow(wc)
    plt.axis('off')
    plt.tight_layout(pad=0)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment