"""psislw function ported to JAX.

Gist by @adamhaber, created October 24, 2022.
"""
import jax
import jax.numpy as np
from functools import partial

@partial(jax.jit, static_argnums=(1, 2, 3))
def psis_single(x, cutoff_ind, cutoffmin, k_min):
    # improve numerical accuracy
    x -= np.max(x)
    # sort the log weights
    sorted_x = np.sort(x)
    # divide log weights into body and right tail
    xcutoff = np.maximum(sorted_x[cutoff_ind], cutoffmin)
    expxcutoff = np.exp(xcutoff)
    # mark non-tail entries as NaN so they sort last and drop out of the fit
    x2 = np.where(x > xcutoff, x, np.nan)
    n2 = np.where(x > xcutoff, 1, 0).sum()
    x2si = np.argsort(x2)
    # tail weights on the linear scale, as exceedances over the cutoff
    x2 = np.exp(x2)
    x2 -= expxcutoff
    # fit a generalized Pareto distribution to the tail
    k, sigma = gpdfitnew(np.sort(x2), n2)
    # flag tails that are too short to fit
    k = np.where(n2 <= 4, np.inf, k)
    # replace tail weights with expected order statistics of the fitted GPD
    sti = np.arange(0.5, x.size)
    sti = np.where(sti < n2, sti / n2, 0)
    qq = gpinv(sti, k, sigma)
    qq += expxcutoff
    # back to the log scale, scattered to the original positions
    qq = np.log(qq)[np.argsort(x2si)]
    x_good = np.where(x > xcutoff, qq, x)
    # truncate smoothed values at the raw maximum (0 after the shift above)
    x_good = np.where(x_good > 0, 0, x_good)
    # keep the smoothed tail only when the fit is usable; selecting with
    # np.where avoids propagating NaNs from the rejected branch
    cond = (k >= k_min) & ~np.isinf(k)
    x = np.where(cond, x_good, x)
    # renormalize weights
    x -= sumlogs(x)
    return np.append(x, k)

def psislw(lw):
    Reff = 1.0
    if lw.ndim == 2:
        n, m = lw.shape
    elif lw.ndim == 1:
        n, m = len(lw), 1
        # treat a 1-d input as a single column so vmap maps over columns
        lw = lw[:, None]
    else:
        raise ValueError("Argument `lw` must be 1 or 2 dimensional.")
    if n <= 1:
        raise ValueError("More than one log-weight needed.")
    # precalculate constants shared by every column
    cutoff_ind = -int(np.ceil(min(0.2 * n, 3 * np.sqrt(n / Reff)))) - 1
    cutoffmin = float(np.log(np.finfo(float).tiny))
    k_min = 1 / 3
    # smooth each column of log weights independently
    res = jax.vmap(lambda x: psis_single(x, cutoff_ind, cutoffmin, k_min))(lw.T)
    return res[:, :-1].T, res[:, -1]
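
# Minimal usage sketch (not part of the original gist; the shape and PRNG seed
# below are illustrative choices). psislw takes log importance weights with one
# column per observation and returns the smoothed log weights, renormalised per
# column, plus one Pareto shape estimate k per column.
def _demo_psislw():
    key = jax.random.PRNGKey(0)
    lw_demo = jax.random.normal(key, (4000, 8))
    lw_smoothed, ks = psislw(lw_demo)
    print(np.exp(lw_smoothed).sum(axis=0))  # each column sums to ~1
    print(ks)  # values above ~0.7 conventionally flag unreliable weights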

@jax.jit
def gpdfitnew(x, n, max_bs=65):
    """Fit a generalized Pareto distribution to `x`, a sorted sample padded
    with trailing NaNs, of which the first `n` entries are valid."""
    if x.ndim != 1 or len(x) <= 1:
        raise ValueError("Invalid input array.")
    x = np.sort(x)
    PRIOR = 3
    m = 30 + np.sqrt(n).astype(int)
    # fixed-length grid of candidate b values so shapes stay static under jit
    bs = np.arange(1, max_bs, dtype=float)
    idx = np.where(bs <= m, bs, m + 1)
    bs -= 0.5
    bs = np.divide(m, bs)
    bs = np.sqrt(bs)
    bs = np.subtract(1, bs)
    # mask out grid points beyond the first m
    bs = np.where(bs < (1 - np.sqrt(m / (m - 0.5))), bs, np.nan)
    bs /= PRIOR * x[(n / 4 + 0.5).astype(int) - 1]
    bs += 1 / np.nanmax(x)
    # profile log-likelihood for each candidate b
    ks = np.negative(bs)
    temp = ks[:, None] * x
    temp = np.log1p(temp)
    ks = np.nanmean(np.where(idx <= m, temp.T, np.nan).T, axis=1)
    L = bs / ks
    L = np.negative(L)
    L = np.log(L)
    L -= ks
    L -= 1
    L *= n
    # normalised quadrature weights over the candidates
    temp = L - L[:, None]
    temp = np.exp(temp)
    w = np.nansum(temp, axis=1)
    w = np.divide(1, w)
    w = np.where(np.isfinite(w), w, 0)
    # remove negligible weights
    dii = w >= 10 * np.finfo(float).eps
    # normalise w
    w /= np.where(dii & (idx <= m), w, 0).sum()
    # posterior mean for b
    b = np.where(dii & (idx <= m), bs * w, 0).sum()
    # Estimate for k; note that we return the negative of Zhang and
    # Stephens's k, because it is the more common parameterisation.
    temp = (-b) * x
    temp = np.log1p(temp)
    k = np.nanmean(temp)
    # estimate for sigma
    sigma = -k / b
    # weakly informative prior for k
    a = 10
    k = k * n / (n + a) + a * 0.5 / (n + a)
    return k, sigma
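
# Quick self-check sketch (not part of the original gist; k_true and sigma_true
# are made-up values). It draws a generalized Pareto sample via inverse-CDF
# sampling with gpinv, then checks that gpdfitnew roughly recovers the
# parameters; the weakly informative prior pulls k slightly towards 0.5.
def _demo_gpdfitnew():
    key = jax.random.PRNGKey(1)
    u = jax.random.uniform(key, (1000,))
    k_true, sigma_true = 0.3, 1.0
    samples = gpinv(u, k_true, sigma_true)
    k_hat, sigma_hat = gpdfitnew(np.sort(samples), samples.size)
    print(k_hat, sigma_hat)  # expected to land near 0.3 and 1.0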

@jax.jit
def gpinv(p, k, sigma):
    """Inverse generalised Pareto distribution function."""
    nanx = np.full(p.shape, np.nan)
    ok = (p > 0) & (p < 1)
    # quantile function; when |k| is below machine epsilon, use the
    # exponential-distribution limit to avoid dividing by k
    temp = np.where(
        np.abs(k) < np.finfo(float).eps,
        np.negative(np.log1p(np.negative(p))),
        np.expm1(-k * np.log1p(np.negative(p))) / k,
    )
    # fast path used when every p lies strictly inside (0, 1)
    x_final_all_ok = temp * sigma
    # otherwise handle the boundary probabilities explicitly
    x = np.where(ok, temp * sigma, 0)
    x = np.where(p == 0, 0, x)
    x_kpos = np.where(p == 1, np.inf, x)
    x_kneg = np.where(p == 1, -sigma / k, x)
    x_final_not_all_ok = np.where(k >= 0, x_kpos, x_kneg)
    return np.where(sigma <= 0, nanx, np.where(np.all(ok), x_final_all_ok, x_final_not_all_ok))

@partial(jax.jit, static_argnames=("axis",))
def sumlogs(x, axis=None):
    """Sum of a log-scale array, i.e. log(sum(exp(x))), computed stably."""
    maxx = x.max(axis=axis, keepdims=True)
    xnorm = x - maxx
    xnorm = np.exp(xnorm)
    out = np.sum(xnorm, axis=axis)
    out = np.log(out)
    out += np.squeeze(maxx)
    return out
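
# Sanity-check sketch (not part of the original gist): sumlogs should agree
# with jax.scipy.special.logsumexp, the library's own stable implementation,
# even where a naive log(sum(exp(x))) would underflow to -inf.
def _demo_sumlogs():
    from jax.scipy.special import logsumexp
    v = np.array([-1000.0, -1001.0, -1002.0])
    print(sumlogs(v), logsumexp(v))  # both ~ -999.59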