@konabuta
Last active April 18, 2020 15:36
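import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy.stats import wasserstein_distance

# Toy example: each list is treated as a sample (empirical distribution).
print(wasserstein_distance([1, 2, 3, 4], [1, 2, 3, 4, 4]))  # about 0.3

# Evaluate two Beta densities on a grid over [0, 1].
x = np.linspace(0, 1, 100)
dist1 = stats.beta.pdf(x, 5, 5)
dist2 = stats.beta.pdf(x, 8, 5)

# The pdf values are weights on the support points x, not samples themselves,
# so pass them as u_weights / v_weights (scipy normalizes the weights).
ws_distance = wasserstein_distance(x, x, u_weights=dist1, v_weights=dist2)

fig, ax = plt.subplots(1, figsize=(8, 6))
ax.fill_between(x, dist1, alpha=0.5)
ax.fill_between(x, dist2, alpha=0.5)
# Invisible line, used only to put the distance into the legend.
ax.plot(0, 0, label="Wasserstein Distance \n = {:3.2f}".format(ws_distance), alpha=0)
ax.legend(loc=2, fontsize=15).get_frame().set_alpha(0)
plt.show()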
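For 1-D inputs, the Wasserstein-1 distance is the area between the two CDFs, W1 = ∫ |F_u(t) − F_v(t)| dt. The sketch below computes that area directly for weighted samples, so the pdf-as-weights call can be checked by hand. The names `wasserstein_1d` and `_step_cdf` are hypothetical illustrations, not part of SciPy; the sketch assumes non-negative weights with a positive sum.

```python
import numpy as np
from scipy import stats
from scipy.stats import wasserstein_distance


def _step_cdf(sorted_values, sorted_weights, t):
    """Normalized CDF of a weighted sample, evaluated at the points t."""
    cum = np.concatenate([[0.0], np.cumsum(sorted_weights)])
    # searchsorted(..., side="right") counts the sample points <= t.
    return cum[np.searchsorted(sorted_values, t, side="right")] / cum[-1]


def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None):
    """W1 between two weighted 1-D distributions as the integral of |F_u - F_v|."""
    u_values = np.asarray(u_values, dtype=float)
    v_values = np.asarray(v_values, dtype=float)
    u_weights = np.ones_like(u_values) if u_weights is None else np.asarray(u_weights, dtype=float)
    v_weights = np.ones_like(v_values) if v_weights is None else np.asarray(v_weights, dtype=float)

    # Sort each sample (and its weights) so cumulative weights form the CDF.
    u_order, v_order = np.argsort(u_values), np.argsort(v_values)
    u_sorted, u_w = u_values[u_order], u_weights[u_order]
    v_sorted, v_w = v_values[v_order], v_weights[v_order]

    # Both CDFs are step functions that only jump at these points,
    # so |F_u - F_v| is constant on each interval [grid[i], grid[i + 1]).
    grid = np.sort(np.concatenate([u_sorted, v_sorted]))
    widths = np.diff(grid)
    u_cdf = _step_cdf(u_sorted, u_w, grid[:-1])
    v_cdf = _step_cdf(v_sorted, v_w, grid[:-1])
    return float(np.sum(np.abs(u_cdf - v_cdf) * widths))


# Toy samples from the gist (about 0.3).
toy = wasserstein_1d([1, 2, 3, 4], [1, 2, 3, 4, 4])

# Beta densities on a grid, with pdf values used as weights.
x = np.linspace(0, 1, 100)
p, q = stats.beta.pdf(x, 5, 5), stats.beta.pdf(x, 8, 5)
grid_w1 = wasserstein_1d(x, x, p, q)
print(toy, grid_w1, wasserstein_distance(x, x, p, q))
```

For the toy lists, the CDFs differ by 0.05, 0.10 and 0.15 on the unit intervals [1, 2), [2, 3) and [3, 4), giving 0.3. SciPy's documentation defines the 1-D distance with the same CDF formula, so the two results should agree up to rounding.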
[Figure: wasserstein]
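As an independent cross-check, which is not part of the original gist, the 1-D distance can also be written as the L1 gap between quantile functions, W1 = ∫₀¹ |F⁻¹(q) − G⁻¹(q)| dq. Beta(8, 5) stochastically dominates Beta(5, 5): their density ratio is proportional to x³, which increases in x. The distance should therefore equal the difference of the means, 8/13 − 1/2 = 3/26 ≈ 0.115. This is a sketch under those assumptions. It uses only `scipy.stats.beta.ppf`, `scipy.stats.beta.mean` and `scipy.integrate.quad`.

```python
import numpy as np
from scipy import stats
from scipy.integrate import quad
from scipy.stats import wasserstein_distance

a1, b1 = 5, 5
a2, b2 = 8, 5

# 1-D W1 as the integral of the absolute gap between the two quantile functions.
w1_quantile, _ = quad(lambda t: abs(stats.beta.ppf(t, a2, b2) - stats.beta.ppf(t, a1, b1)), 0, 1)

# Under stochastic dominance, W1 reduces to the difference of the means (3/26).
w1_means = stats.beta.mean(a2, b2) - stats.beta.mean(a1, b1)

# Grid-based estimate, with pdf values passed as weights on the points x.
x = np.linspace(0, 1, 100)
w1_grid = wasserstein_distance(
    x, x, u_weights=stats.beta.pdf(x, a1, b1), v_weights=stats.beta.pdf(x, a2, b2)
)

print(w1_quantile, w1_means, w1_grid)
```

The three numbers should agree closely. If `wasserstein_distance(dist1, dist2)` is called on the raw pdf arrays instead, it measures the distance between the sets of density values, not between the two distributions on [0, 1], so it will not match 3/26.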
