Created
October 24, 2022 17:41
-
-
Save adamhaber/0556671340e0daa9e2c6e3fd535cd992 to your computer and use it in GitHub Desktop.
The `psislw` function ported to JAX.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division # For Python 2 compatibility | |
import jax.numpy as np | |
import jax | |
from functools import partial | |
from functools import partial | |
@partial(jax.jit, static_argnums=(1, 2, 3))
def psis_single(x, cutoff_ind, cutoffmin, k_min):
    """Pareto-smooth one vector of log importance weights.

    Parameters
    ----------
    x : 1-D array
        Log importance weights for a single column.
    cutoff_ind : int (static)
        Negative index into the sorted weights; the value found there
        (clamped from below by ``cutoffmin``) separates the body from the
        right tail.
    cutoffmin : float (static)
        Lower bound for the tail cutoff, in log space.
    k_min : float (static)
        The tail is replaced by fitted generalised Pareto quantiles only
        when the estimated shape ``k`` is finite and at least this value.

    Returns
    -------
    1-D array of length ``x.size + 1``
        The (possibly smoothed) log weights, normalised so that they
        log-sum-exp to 0, followed by the estimated tail shape ``k``
        (``inf`` when the tail has four or fewer samples).
    """
    # Shift so the largest log weight is 0 (improves numerical accuracy).
    x = x - np.max(x)
    # Divide the log weights into body and right tail.
    x_sorted = np.sort(x)
    xcutoff = np.maximum(x_sorted[cutoff_ind], cutoffmin)
    expxcutoff = np.exp(xcutoff)
    in_tail = x > xcutoff
    n2 = in_tail.sum()
    # Keep a static shape under jit/vmap: body entries become NaN, which
    # sort/argsort place after the n2 tail entries.
    x2 = np.where(in_tail, x, np.nan)
    x2si = np.argsort(x2)
    # Fit a generalised Pareto distribution to the tail exceedances.
    x2 = np.exp(x2) - expxcutoff
    k, sigma = gpdfitnew(np.sort(x2), n2)
    # Too few tail samples for a fit: mark k as infinite (no smoothing).
    k = np.where(n2 <= 4, np.inf, k)
    # Order statistics (i + 0.5) / n2 for the n2 tail slots; the remaining
    # slots get p = 0 and are discarded by the in_tail selection below.
    sti = np.arange(0.5, x.size)
    sti = np.where(sti < n2, sti / n2, 0)
    qq = np.log(gpinv(sti, k, sigma) + expxcutoff)
    # Undo the sort: argsort(x2si) is the inverse permutation of x2si.
    qq = qq[np.argsort(x2si)]
    smoothed = np.where(in_tail, qq, x)
    # Truncate smoothed values to the largest raw log weight, 0.
    smoothed = np.where(smoothed > 0, 0, smoothed)
    # Smooth only for a finite k >= k_min.  Select with np.where rather than
    # blending with a 0/1 mask: 0 * nan and 0 * -inf are nan, so a blend
    # leaks NaNs from the unused branch (e.g. -inf log weights, or the NaN
    # tail quantiles produced when k is inf).
    smooth = (k >= k_min) & ~np.isinf(k)
    x = np.where(smooth, smoothed, x)
    # Renormalise so the weights log-sum-exp to 0.
    x = x - sumlogs(x)
    return np.append(x, k)
def psislw(lw, Reff=1.0):
    """Pareto smoothed importance sampling of log importance weights.

    Parameters
    ----------
    lw : array_like, shape (n,) or (n, m)
        Log importance weights.  For 2-D input each of the ``m`` columns is
        smoothed independently.
    Reff : float, optional
        Relative effective sample size; it enters the tail length
        ``ceil(min(0.2 * n, 3 * sqrt(n / Reff)))``.  Default 1.0.

    Returns
    -------
    lw_out : jax.Array, same shape as ``lw``
        Smoothed log weights, each column normalised to log-sum-exp 0.
    kss : jax.Array
        Estimated Pareto tail shape per column: shape ``(m,)`` for 2-D
        input, a scalar for 1-D input.

    Raises
    ------
    ValueError
        If ``lw`` is not 1- or 2-dimensional, or has fewer than two rows.
    """
    import math

    lw = np.asarray(lw)
    # Work in (at least) float32 so np.finfo below is defined.
    lw = lw.astype(np.result_type(lw.dtype, np.float32))
    if lw.ndim == 2:
        columns = lw.T
    elif lw.ndim == 1:
        # vmap maps over the leading axis, so a 1-D input must be lifted to
        # a single row; mapping the bare vector would hand scalars to
        # psis_single.
        columns = lw[None, :]
    else:
        raise ValueError("Argument `lw` must be 1 or 2 dimensional.")
    n = lw.shape[0]
    if n <= 1:
        raise ValueError("More than one log-weight needed.")

    # Precalculate constants in plain Python: they are static arguments of
    # psis_single, and computing them with jax.numpy would (a) evaluate
    # log(float64 tiny) in float32, where it underflows to log(0) = -inf,
    # and (b) fail if psislw is itself called under jit.
    cutoff_ind = -math.ceil(min(0.2 * n, 3 * math.sqrt(n / Reff))) - 1
    cutoffmin = math.log(float(np.finfo(lw.dtype).tiny))
    k_min = 1 / 3

    res = jax.vmap(lambda x: psis_single(x, cutoff_ind, cutoffmin, k_min))(columns)
    # Each row of res is one column's smoothed weights followed by its k.
    lw_out, kss = res[:, :-1].T, res[:, -1]
    if lw.ndim == 1:
        return lw_out[:, 0], kss[0]
    return lw_out, kss
@partial(jax.jit, static_argnames=("max_bs",))
def gpdfitnew(x, n, max_bs=65):
    """Estimate generalised Pareto distribution parameters (Zhang & Stephens).

    Empirical-Bayes estimate of the shape ``k`` and scale ``sigma`` of a
    generalised Pareto distribution fitted to the exceedances ``x``,
    followed by a weakly informative prior that shrinks ``k`` towards 0.5.

    Parameters
    ----------
    x : 1-D array
        Exceedances to fit.  NaN entries are treated as padding: they sort
        last and the NaN-aware reductions skip them, which lets callers keep
        a static array shape under ``jit``.
    n : int or integer array
        Number of non-NaN entries in ``x`` (the sample size).
    max_bs : int, optional (static)
        The quadrature grid is allocated with ``max_bs - 1`` points, of which
        the first ``m = 30 + floor(sqrt(n))`` are used.  If
        ``n >= (max_bs - 30) ** 2`` the grid is silently truncated; pass a
        larger value for such samples.

    Returns
    -------
    k, sigma : scalar arrays
        Shape and scale estimates.  ``k`` is the negative of Zhang and
        Stephens's k, which is the more common parameterisation.

    Raises
    ------
    ValueError
        If ``x`` is not 1-D or has fewer than two entries.
    """
    if x.ndim != 1 or len(x) <= 1:
        raise ValueError("Invalid input array.")
    # jnp.sort places NaN padding after every real value.
    x = np.sort(x)
    PRIOR = 3
    # Number of quadrature points used by the reference algorithm.
    m = 30 + np.sqrt(n).astype(int)
    # Under jit the grid needs a static size, so allocate max_bs - 1 points
    # and mask out those beyond m.  The mask must keep j == m itself (a
    # strict comparison against the j == m grid value drops that point).
    j = np.arange(1, max_bs, dtype=x.dtype)
    valid = j <= m
    bs = np.where(valid, 1 - np.sqrt(m / (j - 0.5)), np.nan)
    bs /= PRIOR * x[(n / 4 + 0.5).astype(int) - 1]
    bs += 1 / np.nanmax(x)
    # Profile log-likelihood of each grid point b (NaN rows stay NaN).
    ks = np.nanmean(np.log1p(-bs[:, None] * x), axis=1)
    L = n * (np.log(-bs / ks) - ks - 1)
    # Posterior weight of grid point i: 1 / sum_j exp(L[j] - L[i]).
    w = 1 / np.nansum(np.exp(L[None, :] - L[:, None]), axis=1)
    w = np.where(np.isfinite(w), w, 0)
    # Drop masked grid points and negligible weights, then normalise.
    keep = valid & (w >= 10 * np.finfo(float).eps)
    w = np.where(keep, w, 0)
    w /= w.sum()
    # Posterior mean for b (bs is NaN on masked points, so select, not blend).
    b = np.sum(np.where(keep, bs * w, 0))
    # Estimate for k, note that we return a negative of Zhang and
    # Stephens's k, because it is more common parameterisation.
    k = np.nanmean(np.log1p(-b * x))
    # Estimate for sigma.
    sigma = -k / b
    # Weakly informative prior for k: weighted average of k and 0.5.
    a = 10
    k = k * n / (n + a) + a * 0.5 / (n + a)
    return k, sigma
@jax.jit
def gpinv(p, k, sigma):
    """Inverse Generalised Pareto distribution function.

    Parameters
    ----------
    p : array
        Probabilities.
    k : scalar
        Shape parameter; ``|k| < eps`` is treated as the exponential limit.
    sigma : scalar
        Scale parameter; a non-positive value yields all-NaN output.

    Returns
    -------
    array, same shape as ``p``
        ``sigma * expm1(-k * log1p(-p)) / k`` for ``0 < p < 1``
        (``-sigma * log1p(-p)`` in the ``k -> 0`` limit), ``0`` for
        ``p == 0``, ``inf`` (``k >= 0``) or ``-sigma / k`` (``k < 0``) for
        ``p == 1``, and NaN for any other ``p``.
    """
    log1m_p = np.log1p(-p)
    # Both branches are evaluated; np.where picks the valid one, so the
    # division by a (near-)zero k in the unused branch is harmless.
    q = np.where(
        np.abs(k) < np.finfo(float).eps,
        -log1m_p,
        np.expm1(-k * log1m_p) / k,
    )
    ok = (p > 0) & (p < 1)
    # Start from NaN so that p outside [0, 1] (or NaN p) stays NaN.
    x = np.where(ok, q * sigma, np.nan)
    x = np.where(p == 0, 0, x)
    x = np.where(p == 1, np.where(k >= 0, np.inf, -sigma / k), x)
    return np.where(sigma <= 0, np.nan, x)
@partial(jax.jit, static_argnames=("axis",))
def sumlogs(x, axis=None, out=None):
    """Compute ``log(sum(exp(x)))`` along ``axis`` in a numerically stable way.

    Parameters
    ----------
    x : array
        Log-domain input values.
    axis : None, int or tuple of ints, optional (static)
        Axis or axes to reduce; ``None`` reduces over all elements.
    out : ignored
        Kept for signature compatibility; JAX arrays are immutable, so the
        result is always returned rather than written into ``out``.

    Returns
    -------
    array
        ``x`` reduced over ``axis``, with the reduced axes removed.
    """
    maxx = np.max(x, axis=axis, keepdims=True)
    # Shift by the maximum to avoid overflow in exp; use 0 when the maximum
    # is infinite so all -inf input yields -inf instead of inf - inf = nan.
    maxx = np.where(np.isfinite(maxx), maxx, 0)
    total = np.log(np.sum(np.exp(x - maxx), axis=axis, keepdims=True)) + maxx
    # Remove exactly the reduced axes; squeezing every size-1 axis would also
    # drop unrelated ones and produce a wrongly broadcast result.
    return np.squeeze(total, axis=axis)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment