Skip to content

Instantly share code, notes, and snippets.

@shtaxxx
Last active April 21, 2019 14:10
Show Gist options
  • Save shtaxxx/6ca20df2cb7933291fdb9cb02ccf2088 to your computer and use it in GitHub Desktop.
Save shtaxxx/6ca20df2cb7933291fdb9cb02ccf2088 to your computer and use it in GitHub Desktop.
KL-divergence comparison of different quantized distributions generated from a single distribution.
from __future__ import absolute_import
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
# Experiment parameters: one million draws from a zero-mean Gaussian,
# histogrammed at a fine reference resolution and at a coarse resolution.
N = 10 ** 6
loc, scale = 0, 2
epsilon = 1e-5  # additive smoothing constant applied to the histograms later
ref_num_bins, q_num_bins = 1024, 16

dist = np.random.normal(loc=loc, scale=scale, size=N)

# Two clipping ranges applied before coarse quantization: +/-20 (with
# scale=2 this is 10 sigma, i.e. practically unclipped) and +/-7.
q1_dist = dist.clip(-20.0, 20.0)
q2_dist = dist.clip(-7.0, 7.0)

# density=True normalizes each histogram so that sum(hist * bin_width) == 1.
ref_hist, ref_bins = np.histogram(dist, bins=ref_num_bins, density=True)
q1_hist, q1_bins = np.histogram(q1_dist, bins=q_num_bins, density=True)
q2_hist, q2_bins = np.histogram(q2_dist, bins=q_num_bins, density=True)
def to_hist_with_orig_bins(targ_hist, targ_bins, orig_hist, orig_bins):
    """Resample a histogram onto another histogram's bin grid.

    Output bin ``i`` takes the value of the target bin ``k`` that contains
    output bin ``i``'s *left* edge, i.e. ``targ_bins[k] <= orig_bins[i] <
    targ_bins[k + 1]``. Output bins whose left edge lies outside
    ``[targ_bins[0], targ_bins[-1])`` are set to 0.

    Args:
        targ_hist: Values of the histogram to resample, one per target bin.
        targ_bins: Bin edges of ``targ_hist`` (``len(targ_hist) + 1``
            entries; assumed sorted ascending, as ``np.histogram`` returns).
        orig_hist: Histogram whose shape and dtype the result copies.
        orig_bins: Bin edges of ``orig_hist``. Only the left edges
            (``orig_bins[:-1]``) are used for the lookup.

    Returns:
        Array shaped like ``orig_hist`` holding the resampled values.
    """
    ret_hist = np.zeros_like(orig_hist)
    left_edges = np.asarray(orig_bins)[:-1]
    # Index of the last target edge <= each left edge. Unlike a cursor that
    # advances one target bin per original bin, this stays correct when
    # several target bins fit inside a single original bin.
    idx = np.searchsorted(targ_bins, left_edges, side='right') - 1
    # idx == -1: left edge is before the first target edge.
    # idx == len(targ_bins) - 1: left edge is at/after the last target edge.
    valid = (idx >= 0) & (idx < len(targ_bins) - 1)
    ret_hist[valid] = np.asarray(targ_hist)[idx[valid]]
    return ret_hist
# Resample both coarse histograms onto the fine reference grid so that all
# three densities can be compared bin-by-bin.
c_q1_hist = to_hist_with_orig_bins(q1_hist, q1_bins, ref_hist, ref_bins)
c_q2_hist = to_hist_with_orig_bins(q2_hist, q2_bins, ref_hist, ref_bins)

# Width of every reference bin; sumd is the total width of the reference range.
ref_widths = np.diff(ref_bins)
sumd = np.sum(ref_widths)

# Additive smoothing: adding epsilon to every density value raises
# sum(hist * width) by epsilon * sumd. Dividing by (1 + epsilon * sumd)
# restores unit mass for a density that integrated to 1. It also keeps every
# bin strictly positive, which avoids log(0) and division by zero below.
ref_hist = (ref_hist + epsilon) / (1.0 + epsilon * sumd)
c_q1_hist = (c_q1_hist + epsilon) / (1.0 + epsilon * sumd)
c_q2_hist = (c_q2_hist + epsilon) / (1.0 + epsilon * sumd)

# Discrete KL(P || Q) = sum_i P_i * log(P_i / Q_i), where P_i = p_i * w_i is
# the probability mass of bin i. The widths cancel inside the log but not in
# the weight. Without them the sum would be KL divided by the (uniform) bin
# width.
kl_ref = np.sum(ref_widths * ref_hist * np.log(ref_hist / ref_hist))
kl_c_q1 = np.sum(ref_widths * ref_hist * np.log(ref_hist / c_q1_hist))
kl_c_q2 = np.sum(ref_widths * ref_hist * np.log(ref_hist / c_q2_hist))
def to_labels(bins):
    """Return the midpoint of every bin described by consecutive edges.

    Args:
        bins: Sequence of bin edges (list or 1-D array), e.g. the edges
            returned by ``np.histogram``.

    Returns:
        A list of ``len(bins) - 1`` midpoints ``(bins[i] + bins[i + 1]) / 2``.
        The list is empty when fewer than two edges are given.
    """
    return [(left + right) / 2 for left, right in zip(bins[:-1], bins[1:])]
# Bin centres used as x coordinates for each histogram.
ref_labels = to_labels(ref_bins)
q1_labels = to_labels(q1_bins)
q2_labels = to_labels(q2_bins)

# Curves drawn in this order (it fixes the colour cycle): the smoothed
# reference density, the raw coarse densities at their own bin centres, and
# the smoothed coarse densities resampled onto the reference grid, labelled
# with their KL divergence from the reference.
curves = [
    (ref_labels, ref_hist, 'ref'),
    (q1_labels, q1_hist, 'q1'),
    (q2_labels, q2_hist, 'q2'),
    (ref_labels, c_q1_hist, 'q1 KL=%f' % kl_c_q1),
    (ref_labels, c_q2_hist, 'q2 KL=%f' % kl_c_q2),
]

plt.figure(figsize=(10, 5))
for xs, ys, curve_label in curves:
    plt.plot(xs, ys, label=curve_label)
plt.legend(title='histogram', loc='best')
plt.grid()
# Writes the figure to disk; call plt.show() for interactive viewing instead.
plt.savefig('out.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment