# red_or_green_discrepancy
#
# GitHub Gist by @kevindoran, created September 6, 2023 13:39
# (gist id fce2a03739940a193351247685cb533c).
def filter_in_box(ps, l, u):
    """Return the points of *ps* that lie inside the axis-aligned box [l, u).

    A point is kept when every tested coordinate satisfies
    ``l[k] <= p[k] < u[k]``: lower bounds are inclusive and upper bounds are
    exclusive, so the box is half-open.

    Args:
        ps: 2-D array of points, one point per row. Only the first ``len(l)``
            columns are tested (two columns for a 2-D box).
        l: Lower corner of the box, one value per tested coordinate.
        u: Upper corner of the box, same length as ``l``.

    Returns:
        The rows of ``ps`` that fall inside the box, in their original order.
    """
    d = len(l)
    coords = ps[:, :d]
    # Compare all tested coordinates at once (l and u broadcast across rows);
    # a row survives only if every coordinate is inside its interval.
    inside = ((coords >= l) & (coords < u)).all(axis=1)
    return ps[inside]
def discrepancy(ps, l, u, method):
    """Compute the discrepancy of the points of *ps* that lie inside a box.

    Points outside the half-open box [l, u) are discarded with
    ``filter_in_box``. The remaining points are rescaled from the box to the
    unit hypercube (``scipy.stats.qmc.scale(..., reverse=True)``) and scored
    with ``scipy.stats.qmc.discrepancy``.

    Args:
        ps: 2-D array of points, one point per row.
        l: Lower corner of the box.
        u: Upper corner of the box.
        method: Discrepancy method name passed to scipy (e.g. "CD", "WD",
            "MD", "L2-star").

    Returns:
        The discrepancy of the in-box points (lower means more uniform).

    Raises:
        ValueError: If no point of ``ps`` lies inside the box.
    """
    inside = filter_in_box(ps, l, u)
    if len(inside) == 0:
        # An empty sample has no defined discrepancy; fail with a clear
        # message instead of scipy's error for a zero-size array.
        raise ValueError(f"no points lie inside the box [{l}, {u})")
    unit_points = scipy.stats.qmc.scale(inside, l, u, reverse=True)
    return scipy.stats.qmc.discrepancy(unit_points, method=method)
def discrepancy_rand(n, l, u, method, n_trials=int(100e3), seed=None):
    """Sample discrepancies of uniformly random integer point sets in a box.

    Draws ``n_trials`` independent sets of ``n`` points. Each coordinate k is
    uniform on [l[k], u[k]) and is then truncated to an integer with
    ``astype(int)``, which rounds toward zero. Each set is scored with
    ``discrepancy``. The resulting values form the random baseline
    histogram in ``discrepancy_fig``.

    Args:
        n: Number of points in each random set.
        l: Lower corner of the box, one value per dimension.
        u: Upper corner of the box, same length as ``l``.
        method: Discrepancy method name forwarded to ``discrepancy``.
        n_trials: Number of random point sets to draw.
        seed: Optional seed (anything ``np.random.default_rng`` accepts) for
            reproducible draws. ``None`` uses fresh OS entropy, as before.

    Returns:
        A list of ``n_trials`` discrepancy values.
    """
    rng = np.random.default_rng(seed)
    # One vectorised draw for all trials: l and u broadcast along the last
    # axis, giving shape (n_trials, n, d) directly instead of stacking
    # per-trial x/y lists.
    samples = rng.uniform(l, u, size=(n_trials, n, len(l))).astype(int)
    return [discrepancy(points, l, u, method) for points in samples]
def discrepancy_fig(methods=("CD", "WD", "MD", "L2-star"),
                    bottom_left=(105, 105), upper_right=(495, 495)):
    """Compare red/green point-set discrepancies against a random baseline.

    Draws one subplot per discrepancy method. Each subplot contains:

    * a histogram of discrepancies of uniformly random integer point sets
      (``discrepancy_rand``) inside the box [bottom_left, upper_right);
    * dashed vertical lines at the discrepancy of the module-level ``reds``
      (red) and ``greens`` (green) point sets, restricted to the same box.

    ``reds`` and ``greens`` are read as globals. They are presumably 2-D
    arrays of (x, y) points; confirm against where they are defined.

    Args:
        methods: Discrepancy method names, one subplot each.
        bottom_left: Lower corner of the analysis box.
        upper_right: Upper corner of the analysis box.
    """
    # ``discrepancy`` only scores points inside the box, so the random
    # baseline must use the same in-box point count, not len(reds). If reds
    # and greens have different in-box counts, each gets its own baseline.
    null_sizes = sorted({
        len(filter_in_box(pts, bottom_left, upper_right))
        for pts in (reds, greens)
    })
    # Let overlapping baselines stay visible; a single one is drawn as before.
    hist_alpha = 0.5 if len(null_sizes) > 1 else None

    # squeeze=False keeps axs 2-D even when a single method is requested.
    fig, axs = plt.subplots(1, len(methods), figsize=(5 * len(methods), 4),
                            squeeze=False)
    for ax, method in zip(axs[0], methods):
        for n in null_sizes:
            discs = discrepancy_rand(n, bottom_left, upper_right, method)
            ax.hist(discs, bins=100, alpha=hist_alpha,
                    label=f"random (n={n})")
        dreds = discrepancy(reds, bottom_left, upper_right, method)
        dgreens = discrepancy(greens, bottom_left, upper_right, method)
        ax.axvline(dreds, color='red', linestyle='--', label="reds")
        ax.axvline(dgreens, color='green', linestyle='--', label="greens")
        ax.set_title(method)
        ax.legend()
    # Figure.show() does not run a GUI event loop, so in a plain script the
    # window may never appear; pyplot.show() is the documented way to display.
    plt.show()
# Draw the comparison figure when run as a script or notebook, but not when
# this file is imported. Requires ``reds`` and ``greens`` to be defined.
if __name__ == "__main__":
    discrepancy_fig()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment