@michaellindon
Created March 12, 2024 17:20
Anytime-Valid Multinomial Count Sequential p-Value
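For reference, here are the quantities that `sequential_p_value` computes, written out explicitly. The notation is ours, not necessarily the paper's: $x_i$ is `counts`, $n = \sum_i x_i$, $\rho_i$ is `assignment_probabilities`, and $\alpha_i$ is `dirichlet_alpha`. In the code, `lm1` is $\log p_1(x)$, the Dirichlet-multinomial marginal likelihood, and `lm0` is $\log p_0(x)$, the multinomial likelihood under the null.

```latex
\begin{aligned}
p_1(x) &= \frac{n!}{\prod_i x_i!}\,
          \frac{\Gamma\left(\sum_i \alpha_i\right)}{\Gamma\left(n + \sum_i \alpha_i\right)}
          \prod_i \frac{\Gamma(\alpha_i + x_i)}{\Gamma(\alpha_i)}, \\
p_0(x) &= \frac{n!}{\prod_i x_i!} \prod_i \rho_i^{x_i}, \\
p      &= \min\left\{1, \frac{p_0(x)}{p_1(x)}\right\}.
\end{aligned}
```

The ratio $p_1(x)/p_0(x)$ is a Bayes factor. The returned value is its reciprocal, capped at 1. The $n!/\prod_i x_i!$ factors cancel in the ratio.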
from scipy.special import gammaln, xlogy
import numpy as np


def sequential_p_value(counts, assignment_probabilities, dirichlet_alpha=None):
    """
    Compute the sequential p-value for given counts and assignment probabilities.

    The value returned is the reciprocal of the Bayes factor comparing a
    Dirichlet-multinomial alternative with the multinomial null whose cell
    probabilities are ``assignment_probabilities``, capped at 1.

    Lindon, Michael, and Alan Malek.
    "Anytime-Valid Inference For Multinomial Count Data."
    In Advances in Neural Information Processing Systems, 2022.
    https://openreview.net/pdf?id=a4zg0jiuVi

    Parameters
    ----------
    counts : array_like
        The observed counts in each treatment group.
    assignment_probabilities : array_like
        The assignment probabilities to each treatment group under the
        null hypothesis; these should sum to 1.
    dirichlet_alpha : array_like, optional
        The Dirichlet mixture parameter. Defaults to
        ``100 * assignment_probabilities``.

    Returns
    -------
    float
        The sequential p-value.
    """
    counts = np.asarray(counts)
    assignment_probabilities = np.asarray(assignment_probabilities, dtype=float)
    if dirichlet_alpha is None:
        dirichlet_alpha = 100 * assignment_probabilities
    else:
        dirichlet_alpha = np.asarray(dirichlet_alpha, dtype=float)
    n = counts.sum()
    # Log marginal likelihood under the Dirichlet-multinomial alternative.
    lm1 = (
        gammaln(n + 1)
        - gammaln(counts + 1).sum()
        + gammaln(dirichlet_alpha.sum())
        - gammaln(dirichlet_alpha).sum()
        + gammaln(dirichlet_alpha + counts).sum()
        - gammaln((dirichlet_alpha + counts).sum())
    )
    # Log likelihood under the multinomial null; xlogy treats 0 * log(0) as 0.
    lm0 = gammaln(n + 1) + np.sum(
        xlogy(counts, assignment_probabilities) - gammaln(counts + 1), axis=-1
    )
    # min(1, exp(lm0 - lm1)); clipping the exponent at 0 first avoids overflow.
    return float(np.exp(min(0.0, lm0 - lm1)))
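`sequential_p_value` evaluates the p-value for a single snapshot of the counts. When data arrive one unit at a time, a common way to obtain a non-increasing sequence of anytime-valid p-values is to take the running minimum, p_n = min(p_{n-1}, 1/BF_n). As we read the paper, this is the construction it uses, but treat the sketch below as illustrative. The helper names `log_bayes_factor` and `running_p_values` and the parameter `concentration` are hypothetical and not part of the gist. The simulated 55/45 split is made-up data for a sample-ratio-mismatch check against a 50/50 design.

```python
import numpy as np
from scipy.special import gammaln, xlogy


def log_bayes_factor(counts, probs, alpha):
    """Log Bayes factor: Dirichlet-multinomial alternative vs. multinomial null.

    Hypothetical helper mirroring lm1 - lm0 in sequential_p_value; the
    n! / prod(x_i!) terms cancel in the difference, so they are omitted.
    """
    counts = np.asarray(counts, dtype=float)
    probs = np.asarray(probs, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    n = counts.sum()
    log_m1 = (
        gammaln(alpha.sum())
        - gammaln(alpha.sum() + n)
        + np.sum(gammaln(alpha + counts) - gammaln(alpha))
    )
    log_m0 = np.sum(xlogy(counts, probs))
    return log_m1 - log_m0


def running_p_values(assignments, probs, concentration=100.0):
    """Sequential p-value after each observation, kept as a running minimum.

    assignments: iterable of integer group indices (0 .. d-1), one per unit.
    concentration: hypothetical parameter; alpha = concentration * probs
    mirrors the gist's default dirichlet_alpha = 100 * assignment_probabilities.
    """
    probs = np.asarray(probs, dtype=float)
    alpha = concentration * probs
    counts = np.zeros(len(probs))
    p = 1.0
    history = []
    for group in assignments:
        counts[group] += 1
        # 1 / BF capped at 1, computed in log space to avoid overflow.
        p_now = float(np.exp(min(0.0, -log_bayes_factor(counts, probs, alpha))))
        # Running minimum keeps the sequence non-increasing over time.
        p = min(p, p_now)
        history.append(p)
    return history


rng = np.random.default_rng(0)
design_probs = [0.5, 0.5]
# Simulated data: the true split is 55/45 instead of the designed 50/50.
assignments = rng.choice(2, size=5000, p=[0.55, 0.45])
p_values = running_p_values(assignments, design_probs)
print(p_values[-1])
```

Because the running minimum only ever decreases, the experiment can be monitored after every unit and stopped as soon as the p-value falls below the chosen level.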