Skip to content

Instantly share code, notes, and snippets.

@dlubarov
Created February 7, 2023 17:35
Show Gist options
  • Save dlubarov/d5b4f40a7af9d3b622a24002606ab863 to your computer and use it in GitHub Desktop.
def H(x):
    """Binary entropy of x, in bits.

    Computes -x*log2(x) - (1-x)*log2(1-x) using natural logs divided by
    log(2). Requires 0 < x < 1 (checked with assert).
    """
    assert 0 < x < 1
    ln2 = log(2)
    complement = 1 - x
    term_x = -x * log(x) / ln2
    term_complement = complement * log(complement) / ln2
    return term_x - term_complement
def D(a, p):
    """Kullback-Leibler divergence D(Bern(a) || Bern(p)), in nats.

    Computes a*ln(a/p) + (1-a)*ln((1-a)/(1-p)). Both arguments must lie
    strictly inside (0, 1) (checked with assert).
    """
    assert 0 < a < 1
    assert 0 < p < 1
    a_bar = 1 - a
    head = a * log(a / p)
    tail = a_bar * (log(a_bar) - log(1 - p))
    return head + tail
def log_pr_exactly_c_cols(n, m, d, c):
    """Log-domain upper bound on Pr[the n rows' supports cover exactly c of m columns].

    Bounds the event by "all n supports fall inside some c-column set":
        C(m, c) * (C(c, d) / C(m, d)) ** n
    evaluated with the log_binom_ub / log_binom_lb helpers. This presumably
    assumes each row's support is a uniformly random d-subset of the m
    columns (NOTE(review): TODO confirm). The result is clamped to <= 0,
    i.e. Pr <= 1.

    Requires d < m, c < m and c >= d (a union of d-column supports cannot
    cover fewer than d columns); violations raise AssertionError.
    """
    assert m > d
    assert m > c
    assert c >= d
    # log_binom_ub(c, d) requires d < c; for c == d the count C(d, d) is
    # exactly 1, so its log is 0.
    log_subsets_inside = log_binom_ub(c, d) if c > d else 0
    p = log_binom_ub(m, c) + n * (log_subsets_inside - log_binom_lb(m, d))
    return min(p, 0)
def log_pr_sum_low_weight_given_c_cols(n, m, d, c, zeta):
    """Log-domain bound on Pr[at most zeta of c nonzero columns have a nonzero sum].

    Each of the c columns is treated as having a nonzero sum with
    probability (q - 2) / (q - 1), using the module-level q. The count is
    then bounded with log_binom_cdf_ub. When c <= zeta the event is certain
    and 0 (= log 1) is returned. n, m and d are not used here.
    """
    if c > zeta:
        p_nonzero_sum = (q - 2) / (q - 1)
        return log_binom_cdf_ub(zeta, c, p_nonzero_sum)
    return 0  # Pr = 1
def log_pr_exactly_c_cols_and_sum_low_weight(n, m, d, c, zeta):
    """Log-domain bound on Pr[exactly c columns are hit AND the sum has low weight].

    The joint probability is the product of the column-count bound
    (log_pr_exactly_c_cols) and the conditional low-weight bound
    (log_pr_sum_low_weight_given_c_cols), so the logs are added.
    """
    return (log_pr_exactly_c_cols(n, m, d, c)
            + log_pr_sum_low_weight_given_c_cols(n, m, d, c, zeta))
def log_pr_xM_low_weight(m, d, epsilon, zeta):
    """Heuristic log-domain bound on Pr[at most zeta of the m columns of x*M are nonzero].

    A column is counted as nonzero when at least one of the epsilon rows
    selected by x touches it. Each row has d nonzeros (presumably at
    distinct positions — NOTE(review): TODO confirm). The m columns are then
    modelled as independent Bernoulli(p_col_nonzero) trials, and the count is
    bounded with log_binom_cdf_ub.

    Caveats:
    - The columns are NOT independent (TODO), so this is a heuristic, not a
      rigorous bound.
    - The factor (q - 2)/(q - 1), i.e. the probability that a touched
      column actually has a nonzero sum, is deliberately omitted.
    """
    # Pr[one row misses a fixed column]. (1 - 1/m)**d is the
    # with-replacement form (TODO: approx). If positions are distinct, the
    # exact value is (m - d)/m. By Bernoulli's inequality this form is never
    # smaller than that, which lowers p_col_nonzero and so can only loosen
    # the bound. `**` is used because `^` is XOR outside Sage.
    p_row_misses_col = ((m - 1) / m) ** d
    p_all_rows_miss_col = p_row_misses_col ** epsilon
    p_col_nonzero = 1 - p_all_rows_miss_col
    return log_binom_cdf_ub(zeta, m, p_col_nonzero)
def log_pr_any_xM_low_weight(n, m, d, epsilon, zeta):
    """Log-domain union bound on Pr[some input x makes x*M low weight].

    The log of the number of inputs is bounded by
    log_binom_ub(n, epsilon) + epsilon*log(q - 1). This is the log of the
    count of length-n vectors with exactly epsilon nonzero entries drawn
    from q symbols; q is the module-level value.

    That count is added to the per-input bound from log_pr_xM_low_weight.
    Both terms are printed, truncated to int, before returning.
    """
    log_per_input = log_pr_xM_low_weight(m, d, epsilon, zeta)
    log_count = log_binom_ub(n, epsilon)
    log_count += epsilon * log(q - 1)
    print(int(log_per_input), int(log_count))
    return log_per_input + log_count
def log_binom_ub(n, k):
    """Upper bound on ln C(n, k), in nats, for 0 < k < n.

    Uses C(n, k) <= sqrt(n / (2*pi*k*(n - k))) * exp(n * H_e(k/n)), where
    H_e is the binary entropy in nats (MacWilliams & Sloane, Ch. 10).

    The result is in nats, like D, log(q - 1) and exp(), which its results
    are combined with. Non-integral k (e.g. log_binom_ub(n, epsilon)) is
    evaluated through the same formula as a continuous extension.

    Raises AssertionError unless 0 < k < n.
    """
    assert 0 < k < n
    # n * H_e(k/n) written out: n ln n - k ln k - (n - k) ln(n - k).
    n_entropy = n * log(n) - k * log(k) - (n - k) * log(n - k)
    # 2*pi (not 8) is the constant of the upper bound; 8 gives the lower bound.
    return n_entropy + (log(n) - log(2 * pi) - log(k) - log(n - k)) / 2
def log_binom_lb(n, k):
    """Lower bound on ln C(n, k), in nats, for 0 < k < n.

    Uses C(n, k) >= sqrt(n / (8*k*(n - k))) * exp(n * H_e(k/n)), where H_e
    is the binary entropy in nats (MacWilliams & Sloane, Ch. 10).

    The entropy term must be in nats. Mixing a bits-valued entropy with
    natural-log corrections overshoots ln C(n, k) for large n, so the
    result would no longer be a lower bound.

    Raises AssertionError unless 0 < k < n.
    """
    assert 0 < k < n
    # n * H_e(k/n) written out: n ln n - k ln k - (n - k) ln(n - k).
    n_entropy = n * log(n) - k * log(k) - (n - k) * log(n - k)
    return n_entropy + (log(n) - log(8) - log(k) - log(n - k)) / 2
def log_binom_cdf_ub(k, n, p):
    """Upper bound on ln Pr[X <= k], in nats, for X ~ Binomial(n, p).

    - For 0 < k < n*p: the Chernoff-Hoeffding bound -n * D(k/n, p).
    - For k >= n*p: that bound is invalid (it would fall below the true
      probability), so the trivial bound Pr <= 1, i.e. 0, is returned.
    - For k == 0: the exact value n*ln(1 - p) is returned, which is also
      the limit of the Chernoff expression. D itself requires k/n > 0.

    Requires 0 < p < 1 and k >= 0; violations raise AssertionError.
    """
    assert 0 < p < 1
    assert k >= 0
    if k >= n * p:
        return 0  # at or above the mean: Pr <= 1
    if k == 0:
        return n * log(1 - p)  # Pr[X = 0] = (1 - p)**n exactly
    return -n * D(k/n, p)
def b_prime(k):
    """Return (b + k/n + ((r - 1) + 110/n) / log2(q)) * n.

    Reads the module-level parameters b, n, r and q.
    NOTE(review): the meaning of this quantity and of the constant 110 is
    not stated here — TODO document.
    """
    # No module-level `import math` is visible, so bring it into scope here.
    import math
    return (b + k/n + ((r - 1) + 110/n) / math.log2(q)) * n
# Module-level parameters read by the helpers above.
# q = 2**127 - 1 (a Mersenne prime) is the alphabet size: the helpers use
# q - 1 as the number of nonzero values.
# `**` is used rather than Sage's `^`, which is XOR in plain Python.
# NOTE(review): the meanings of r, a and b are not stated here — TODO document.
q = 2**127 - 1
r = 1.72
a = 0.238
b = 0.1205
n = int(1e6)
d = 20  # passed as `d` (nonzeros per row) to the bounds
rows = int(a * r * n)
cols = int((r - 1 - r*a) * n)

if __name__ == "__main__":
    epsilon = .07 * a * r * n  # weight of the input x
    zeta = .12 * n  # low-weight threshold for x*M
    log_bound = float(log_pr_any_xM_low_weight(rows, cols, d, epsilon, zeta))
    # A bare expression is discarded when run as a script, so print it.
    print(log_bound)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment