Skip to content

Instantly share code, notes, and snippets.

@josharian

josharian/ve2.py Secret

Created October 18, 2024 16:58
Show Gist options
  • Save josharian/498a5951462721fa5a214d685dc32a53 to your computer and use it in GitHub Desktop.
Save josharian/498a5951462721fa5a214d685dc32a53 to your computer and use it in GitHub Desktop.
varentropy vs entropy, three outcomes
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FormatStrFormatter, MaxNLocator
# Number of samples along each axis of the bounding rectangle; a larger value
# gives a smoother image at the cost of more memory and compute.
resolution = 500

# Cartesian grid over the bounding box of an equilateral triangle with unit
# side length (width 1, height sqrt(3)/2).
x = np.linspace(0, 1, resolution)
y = np.linspace(0, np.sqrt(3)/2, resolution)
X, Y = np.meshgrid(x, y)

# Cartesian -> barycentric coordinates. p2 grows linearly with height, p1 is
# the horizontal position corrected for the slant of the triangle, and p0 is
# the remaining probability mass. The corners map to (0, 0) -> p0 = 1,
# (1, 0) -> p1 = 1 and (0.5, sqrt(3)/2) -> p2 = 1.
P2 = (2 / np.sqrt(3)) * Y
P1 = X - 0.5 * P2
P0 = 1 - P1 - P2

# A grid point is a probability distribution only when no component is
# negative; everything else lies outside the triangle.
valid = (P0 >= 0) & (P1 >= 0) & (P2 >= 0)
outside = ~valid

# Keep every probability strictly inside (0, 1) so that log2 stays finite.
epsilon = 1e-12
P0_safe, P1_safe, P2_safe = (
    np.clip(p, epsilon, 1 - epsilon) for p in (P0, P1, P2)
)

# Per-outcome surprisal, -log2(p), in bits.
log_P0, log_P1, log_P2 = (-np.log2(p) for p in (P0_safe, P1_safe, P2_safe))

# Entropy is the probability-weighted mean surprisal; points outside the
# triangle are blanked with NaN so they do not take part in min/max below.
entropy = np.where(
    valid,
    P0_safe * log_P0 + P1_safe * log_P1 + P2_safe * log_P2,
    np.nan,
)

# Varentropy is the probability-weighted variance of the surprisal around the
# entropy. It inherits NaN from `entropy` outside the triangle; the explicit
# mask keeps that intent visible.
varentropy = np.where(
    valid,
    P0_safe * (log_P0 - entropy)**2
    + P1_safe * (log_P1 - entropy)**2
    + P2_safe * (log_P2 - entropy)**2,
    np.nan,
)

# Min-max rescale both quantities to [0, 1] (NaN-aware extrema), then replace
# the NaN entries outside the triangle with 0 so the arrays are valid colours.
entropy_min = np.nanmin(entropy)
entropy_max = np.nanmax(entropy)
entropy_norm = (entropy - entropy_min) / (entropy_max - entropy_min)
entropy_norm[outside] = 0

varentropy_min = np.nanmin(varentropy)
varentropy_max = np.nanmax(varentropy)
varentropy_norm = (varentropy - varentropy_min) / (varentropy_max - varentropy_min)
varentropy_norm[outside] = 0

# Colour encoding: red carries entropy, blue carries varentropy, green is
# unused and stays at zero.
Red = entropy_norm
Blue = varentropy_norm
Green = np.zeros_like(Red)

# Assemble the image with the colour channels on the last axis and paint
# everything outside the triangle white.
RGB = np.stack((Red, Green, Blue), axis=-1)
RGB[outside] = 1.0
# Render the simplex image next to a 2-D colour legend.
#
# Reads from the computation above: RGB (the coloured simplex image) and the
# raw ranges entropy_min/entropy_max and varentropy_min/varentropy_max, which
# are used to label the legend axes in real units.

# Two panels side by side; the simplex image is twice as wide as the legend.
fig = plt.figure(figsize=(12, 6))
gs = fig.add_gridspec(1, 2, width_ratios=[1, 0.5], wspace=0.3)

# Main panel. origin='lower' puts row 0 (y = 0) at the bottom so the triangle
# points upwards, and the extent maps the image onto the unit-side triangle's
# bounding box so the vertex annotations below land on the corners.
ax_main = fig.add_subplot(gs[0, 0])
ax_main.imshow(RGB, origin='lower', extent=(0, 1, 0, np.sqrt(3)/2), aspect='equal')
ax_main.set_xlabel('Probability Simplex X-coordinate')
ax_main.set_ylabel('Probability Simplex Y-coordinate')
ax_main.set_title('Entropy and Varentropy over Probability Simplex')

# Label each corner with the outcome that has probability 1 there.
ax_main.text(0, 0, 'Outcome 0', ha='center', va='top')
ax_main.text(1, 0, 'Outcome 1', ha='center', va='top')
ax_main.text(0.5, np.sqrt(3)/2, 'Outcome 2', ha='center', va='bottom')

# The Cartesian coordinates carry no meaning of their own, so hide the ticks.
ax_main.set_xticks([])
ax_main.set_yticks([])

# 2-D colour legend: sweep normalised varentropy along the columns and
# normalised entropy along the rows, then colour with the same mapping as the
# main image (red = entropy, blue = varentropy, green = 0).
legend_resolution = 256
v_norm = np.linspace(0, 1, legend_resolution)
e_norm = np.linspace(0, 1, legend_resolution)
V_norm, E_norm = np.meshgrid(v_norm, e_norm)
Red_leg = E_norm
Blue_leg = V_norm
Green_leg = np.zeros_like(Red_leg)
RGB_leg = np.stack((Red_leg, Green_leg, Blue_leg), axis=-1)

# The extent converts the normalised [0, 1] sweep back into raw units, so the
# axes read in actual varentropy (x) and entropy (y) values.
ax_leg = fig.add_subplot(gs[0, 1])
ax_leg.imshow(RGB_leg, origin='lower', extent=(varentropy_min, varentropy_max, entropy_min, entropy_max), aspect='auto')
ax_leg.set_xlabel('Varentropy')
ax_leg.set_ylabel('Entropy')
ax_leg.set_title('Color Legend')

# Tick placement: MaxNLocator picks at most five intervals at round values,
# and the formatter prints two decimals. The earlier fixed set_xticks /
# set_yticks calls were dropped because installing these locators replaced
# them anyway, so the rendered ticks are unchanged.
ax_leg.xaxis.set_major_locator(MaxNLocator(5))
ax_leg.yaxis.set_major_locator(MaxNLocator(5))
ax_leg.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
ax_leg.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment