Skip to content

Instantly share code, notes, and snippets.

@HYChou0515
Created May 31, 2022 14:35
Show Gist options
  • Save HYChou0515/da82f2255674d55e5dfdc7576214da65 to your computer and use it in GitHub Desktop.
Save HYChou0515/da82f2255674d55e5dfdc7576214da65 to your computer and use it in GitHub Desktop.
import itertools
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
def get_random_dist(seed=None):
    """Return a list of random draws from a randomly weighted mixture.

    The mixture has 15 components: five Gaussians, five uniforms and five
    Poissons, each of which re-draws its own parameters on every sample.
    The component weights are themselves random (fifth powers of a shifted
    ramp, normalised and shuffled), so a few components usually dominate.

    Args:
        seed: Anything accepted by ``np.random.default_rng``; ``None`` gives
            a non-reproducible generator.

    Returns:
        A Python list with ``int(exp(3 + 2*U))`` values (U uniform on
        [0, 1)), i.e. between 20 and 148 samples.
    """
    rng = np.random.default_rng(seed)
    nsamples = int(np.exp(rng.random() * 2 + 3))

    def gaussian(shift, max_sigma):
        # The mean is drawn before sigma (left-to-right argument evaluation).
        return lambda: rng.normal(rng.random() * 20 - shift,
                                  rng.random() * max_sigma)

    def uniform(width, offset):
        return lambda: rng.random() * width + offset

    def poisson(max_extra):
        return lambda: rng.poisson(rng.random() * max_extra + 1)

    samplers = [
        gaussian(10, 5),
        gaussian(5, 4),
        gaussian(15, 2),
        gaussian(0, 4),
        gaussian(20, 5),
        uniform(50, -25),
        uniform(20, -25),
        uniform(40, -25),
        uniform(10, 2),
        uniform(12, -1),
        poisson(5),
        poisson(3),
        poisson(6),
        poisson(10),
        poisson(21),
    ]

    # Steeply increasing weights (ramp ** 5), normalised to sum to one and
    # then shuffled so that the favoured components differ between calls.
    n_samplers = len(samplers)
    ramp = np.arange(n_samplers) + rng.beta(1, 5) * n_samplers
    weights = np.power(ramp, 5)
    weights /= sum(weights)
    rng.shuffle(weights)

    # Pick a component for each sample, then draw one value from it.
    return [rng.choice(samplers, p=weights)() for _ in range(nsamples)]
def get_score(yes, no):
    """Score, on a 0-100 scale, how strongly ``yes`` values exceed ``no`` values.

    Two one-sided p-values are combined by their geometric mean:

    * the smallest Fisher exact p-value (``alternative='greater'``) over all
      2x2 tables obtained by splitting both samples at ``x > cut``, where
      ``cut`` ranges over every observed value (no multiple-testing
      correction is applied);
    * a one-sample t-test of ``yes`` against the median of the pooled data
      (``alternative='greater'``).

    The combined p-value ``s`` is mapped through a piecewise-linear curve on
    ``-log10(s)``: 0 -> 0, 1 -> 60, 4 -> 90, 100 -> 100.

    Args:
        yes: Re-iterable 1-D collection of numbers (the "positive" group).
        no: Re-iterable 1-D collection of numbers (the "negative" group).

    Returns:
        A ``numpy.float64`` in [0, 100]. An empty ``yes`` yields 0, and so
        does a NaN combined p-value (e.g. an undefined t-test).
    """
    yes_arr = np.asarray(yes)
    no_arr = np.asarray(no)

    # Nothing to score: the t-test is undefined without `yes` data, and with
    # both groups empty there would be no cut points to take a minimum over.
    if yes_arr.size == 0:
        return np.float64(0.0)

    pooled = np.concatenate([yes_arr, no_arr])

    def _get_fisher(cut):
        # 2x2 table: rows = above / not above the cut, columns = yes / no.
        yes_above = int(np.count_nonzero(yes_arr > cut))
        no_above = int(np.count_nonzero(no_arr > cut))
        table = [
            [yes_above, no_above],
            [yes_arr.size - yes_above, no_arr.size - no_above],
        ]
        return stats.fisher_exact(table, alternative='greater')[1]

    # Duplicate cut values produce identical tables, so test each value once.
    pfisher = min(_get_fisher(cut) for cut in np.unique(pooled))
    pttest = stats.ttest_1samp(
        yes_arr, np.median(pooled), alternative='greater'
    ).pvalue

    # A p-value of exactly 0 makes log() diverge; that is intended (the score
    # saturates at 100), so silence the divide-by-zero warning explicitly.
    with np.errstate(divide='ignore'):
        s = stats.gmean([pfisher, pttest])
        if np.isnan(s):
            s = 1.0  # undefined evidence counts as no evidence
        s = np.clip(s, 0, 1)
        strength = -np.log10(s)

    return np.clip(np.interp(strength, [0, 1, 4, 100], [0, 60, 90, 100]), 0, 100)
# Build 100 random (yes, no) sample pairs together with their score;
# each entry is [yes_sample, no_sample, score].
dists = []
for i in range(100):
    d = [get_random_dist() for _ in range(2)]
    dists.append(d + [get_score(*d)])
# dists = sorted(dists, key=lambda x: x[2])
# fig, axes = plt.subplots(nrows=2, ncols=3)
# for i in range(3):
# axes[0][i].scatter(np.random.random(len(dists[i][0])), dists[i][0], marker='o')
# axes[0][i].scatter(np.random.random(len(dists[i][1]))+3, dists[i][1], marker='x')
#
# for i in range(3):
# axes[1][i].scatter(np.random.random(len(dists[-1-i][0])), dists[-1-i][0], marker='o')
# axes[1][i].scatter(np.random.random(len(dists[-1-i][1]))+3, dists[-1-i][1], marker='x')
# fig.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment