Skip to content

Instantly share code, notes, and snippets.

@drazenz
Created April 14, 2019 19:27
Show Gist options
  • Save drazenz/99e9a0a2b29a275170740eff0e215e4b to your computer and use it in GitHub Desktop.
Save drazenz/99e9a0a2b29a275170740eff0e215e4b to your computer and use it in GitHub Desktop.
# Step 1 - Make a scatter plot with square markers, set column names as labels
def heatmap(x, y, size, size_scale=500):
    """Draw a square-marker scatter plot that works as a simple heatmap.

    Each pair ``(x[i], y[i])`` becomes one square marker. The distinct values
    of ``x`` and ``y`` are sorted and mapped to consecutive integer grid
    positions (0, 1, 2, ...). The axes are then labelled with the original
    values.

    Args:
        x: Labels for the horizontal axis. This is a pandas Series, since
            ``.unique()`` and ``.map()`` are called on it.
        y: Labels for the vertical axis, aligned element-wise with ``x``.
        size: Per-point marker sizes, aligned with ``x`` and ``y``. The value
            ``size * size_scale`` is passed as matplotlib's ``s`` argument
            (marker area), so values should be non-negative.
        size_scale: Multiplier applied to ``size``. The default of 500
            matches the value that was previously hard-coded.

    NOTE(review): this relies on a module-level ``plt`` (presumably
    ``import matplotlib.pyplot as plt``) that is not visible here.
    """
    _, ax = plt.subplots()

    # Map the sorted distinct labels to integer coordinates on each axis.
    x_labels = sorted(x.unique())
    y_labels = sorted(y.unique())
    x_to_num = {label: i for i, label in enumerate(x_labels)}
    y_to_num = {label: i for i, label in enumerate(y_labels)}

    ax.scatter(
        x=x.map(x_to_num),  # Use mapping for x
        y=y.map(y_to_num),  # Use mapping for y
        s=size * size_scale,  # Marker area, proportional to the size parameter
        marker='s',  # Use square as scatterplot marker
    )

    # Label i sits at integer position i, so ticks are simply 0..n-1.
    ax.set_xticks(list(range(len(x_labels))))
    ax.set_xticklabels(x_labels, rotation=45, horizontalalignment='right')
    ax.set_yticks(list(range(len(y_labels))))
    ax.set_yticklabels(y_labels)
# Load the cleaned autos dataset and compute the pairwise correlation matrix
# for a handful of its numeric columns.
data = pd.read_csv('https://raw.githubusercontent.com/drazenz/heatmap/master/autos.clean.csv')
columns = ['bore', 'stroke', 'compression-ratio', 'horsepower', 'city-mpg', 'price']
corr = data[columns].corr()

# Unpivot the square matrix into long form (one row per label pair), so that
# the x labels, y labels and coefficients become three parallel columns.
corr = pd.melt(corr.reset_index(), id_vars='index')
corr.columns = ['x', 'y', 'value']

# Marker size reflects correlation strength regardless of sign.
heatmap(x=corr['x'], y=corr['y'], size=corr['value'].abs())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment