-
-
Save josharian/498a5951462721fa5a214d685dc32a53 to your computer and use it in GitHub Desktop.
varentropy vs entropy, three outcomes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FormatStrFormatter, MaxNLocator
# Sample the bounding box of an equilateral triangle with unit side length.
# Its vertices are (0, 0) -> outcome 0, (1, 0) -> outcome 1 and
# (1/2, sqrt(3)/2) -> outcome 2, so each point inside the triangle stands for
# one probability distribution (p0, p1, p2) over the three outcomes.
resolution = 500  # Grid points per axis; higher gives a smoother image.
x = np.linspace(0, 1, resolution)
y = np.linspace(0, np.sqrt(3) / 2, resolution)
X, Y = np.meshgrid(x, y)

# Invert the barycentric map (x, y) = p1 * (1, 0) + p2 * (1/2, sqrt(3)/2):
# the height alone fixes p2, x then fixes p1, and p0 is the remainder.
P2 = (2 / np.sqrt(3)) * Y
P1 = X - 0.5 * P2
P0 = 1 - P1 - P2

# Grid points outside the triangle have at least one negative coordinate;
# only points with every p_i >= 0 are genuine probability distributions.
valid = (P0 >= 0) & (P1 >= 0) & (P2 >= 0)
# Per-pixel Shannon entropy and varentropy, both in bits:
#   entropy    H(p) = sum_i p_i * s_i
#   varentropy V(p) = sum_i p_i * (s_i - H(p))**2
# where s_i = -log2(p_i) is the surprisal of outcome i, so V is the variance
# of the surprisal under p.
epsilon = 1e-12  # Floor applied only inside the log so log2(0) never occurs.
probs = np.stack((P0, P1, P2))  # Leading axis indexes the outcome.
# The weights are the probabilities themselves (negative values only occur
# outside the triangle and are masked below). Flooring only the log argument
# means a zero-probability outcome contributes exactly 0 (0 * log 0 = 0)
# instead of a spurious epsilon-sized term.
weights = np.clip(probs, 0.0, 1.0)
surprisal = -np.log2(np.clip(probs, epsilon, 1.0))

entropy = np.sum(weights * surprisal, axis=0)
entropy[~valid] = np.nan  # Points outside the simplex have no distribution.

# entropy broadcasts across the outcome axis; NaNs propagate to invalid points.
varentropy = np.sum(weights * (surprisal - entropy) ** 2, axis=0)
varentropy[~valid] = np.nan
# Rescale entropy and varentropy linearly onto [0, 1] so each can drive a
# colour channel directly. The raw min/max are kept for labelling the legend.
entropy_min, entropy_max = np.nanmin(entropy), np.nanmax(entropy)
varentropy_min, varentropy_max = np.nanmin(varentropy), np.nanmax(varentropy)


def _to_unit_interval(values, lo, hi):
    """Map ``values`` linearly from ``[lo, hi]`` onto ``[0, 1]``.

    If the range is degenerate (``hi == lo``, e.g. a one-point grid), every
    entry maps to 0 instead of producing 0/0 NaNs. In the normal case NaN
    inputs stay NaN; the caller zeroes them with the validity mask.
    """
    span = hi - lo
    if span > 0:
        return (values - lo) / span
    return np.zeros_like(values)


entropy_norm = _to_unit_interval(entropy, entropy_min, entropy_max)
entropy_norm[~valid] = 0  # Replace the NaNs outside the simplex.
varentropy_norm = _to_unit_interval(varentropy, varentropy_min, varentropy_max)
varentropy_norm[~valid] = 0

# Colour encoding: red = normalized entropy, blue = normalized varentropy,
# green is held at 0.
Red = entropy_norm
Blue = varentropy_norm
Green = np.zeros_like(Red)
RGB = np.stack((Red, Green, Blue), axis=-1)
RGB[~valid] = [1, 1, 1]  # Paint points outside the simplex white.
# Figure layout: the simplex image on the left and a 2D colour key on the
# right, half as wide.
fig = plt.figure(figsize=(12, 6))
gs = fig.add_gridspec(1, 2, width_ratios=[1, 0.5], wspace=0.3)

ax_main = fig.add_subplot(gs[0, 0])
# origin='lower' puts grid row 0 (y = 0) at the bottom, so the image lines up
# with the triangle coordinates used to build the grid.
ax_main.imshow(RGB, origin='lower', extent=(0, 1, 0, np.sqrt(3) / 2), aspect='equal')
ax_main.set_xlabel('Probability Simplex X-coordinate')
ax_main.set_ylabel('Probability Simplex Y-coordinate')
# The 'Outcome 2' label sits just above the top edge at the same x as the
# centred title; the default title pad (6 pt) is smaller than one text line,
# so extra padding keeps the two from overlapping.
ax_main.set_title('Entropy and Varentropy over Probability Simplex', pad=20)

# Label each vertex with the outcome whose probability is 1 there.
ax_main.text(0, 0, 'Outcome 0', ha='center', va='top')
ax_main.text(1, 0, 'Outcome 1', ha='center', va='top')
ax_main.text(0.5, np.sqrt(3) / 2, 'Outcome 2', ha='center', va='bottom')

# The raw x/y coordinates carry no meaning for the reader; hide the ticks.
ax_main.set_xticks([])
ax_main.set_yticks([])
# 2D colour key: every (varentropy, entropy) pair in the observed ranges,
# coloured with the same red/blue mapping as the main image. It spans the full
# rectangle of ranges, so not every pair shown is attained by a distribution.
legend_resolution = 256
v_norm = np.linspace(0, 1, legend_resolution)
e_norm = np.linspace(0, 1, legend_resolution)
V_norm, E_norm = np.meshgrid(v_norm, e_norm)  # varentropy along x, entropy along y
Red_leg = E_norm
Blue_leg = V_norm
Green_leg = np.zeros_like(Red_leg)
RGB_leg = np.stack((Red_leg, Green_leg, Blue_leg), axis=-1)

ax_leg = fig.add_subplot(gs[0, 1])
# The extent maps the normalized [0, 1] square back to the raw value ranges,
# so the tick labels read in bits.
ax_leg.imshow(RGB_leg, origin='lower', extent=(varentropy_min, varentropy_max, entropy_min, entropy_max), aspect='auto')
ax_leg.set_xlabel('Varentropy')
ax_leg.set_ylabel('Entropy')
ax_leg.set_title('Color Legend')
# At most five ticks per axis, printed with two decimals. (The former explicit
# set_xticks/set_yticks calls were immediately replaced by these locators.)
for leg_axis in (ax_leg.xaxis, ax_leg.yaxis):
    leg_axis.set_major_locator(MaxNLocator(5))
    leg_axis.set_major_formatter(FormatStrFormatter('%.2f'))

plt.tight_layout()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment