Skip to content

Instantly share code, notes, and snippets.

@olszewskip
Last active July 16, 2021 12:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save olszewskip/3e677477cee1b74f305ef3a032f92ffe to your computer and use it in GitHub Desktop.
Save olszewskip/3e677477cee1b74f305ef3a032f92ffe to your computer and use it in GitHub Desktop.
Hail_ordinal_reg_dummy_1.py
import numpy as np
rng = np.random.default_rng(seed=42)
import hail as hl
# Initialise Hail; no arguments are passed, so Hail's defaults apply.
hl.init()
# dummy data generation
def sigmoid(arr):
    """Return the elementwise logistic function ``1 / (1 + exp(-arr))``.

    The computation is split by sign so that ``np.exp`` is only ever called
    on non-positive arguments, which keeps it from overflowing for inputs of
    large magnitude.

    Args:
        arr: Array-like of real numbers (an ndarray, or nested lists/tuples).
            It is converted to a float ndarray first, so integer input is
            accepted as well.

    Returns:
        A float ndarray with the same shape as ``arr``.
    """
    # Coerce once so that lists and integer arrays work with the boolean-mask
    # indexing below (the masks require an ndarray).
    arr = np.asarray(arr, dtype=float)
    result = np.empty(arr.shape, dtype=float)
    positive = arr > 0
    negative = ~positive
    # arr > 0: exp(-arr) lies in (0, 1), so the denominator stays finite.
    result[positive] = 1 / (1 + np.exp(-arr[positive]))
    # Everything else (including NaN, for which `arr > 0` is False):
    # exp(arr) / (1 + exp(arr)) is the same function written without exp(-arr).
    exp_arr = np.exp(arr[negative])
    result[negative] = exp_arr / (exp_arr + 1)
    return result
# Problem sizes: y below takes values in 1..k, X has m columns and n rows.
k = 3
m = 4
n = 123
# Parameters used only to simulate the data: one coefficient per column of X
# and k - 1 thresholds.
_beta = np.array([0.1, 0.2, 1, 2])
_theta = np.array([10, 15])
# Feature matrix of shape (n, m), drawn uniformly from [0, 8).
X = rng.uniform(0, 8, (n, m))
# Linear predictor, one value per row of X.
_Xbeta = X @ _beta
# Shape (n, k - 1): sigmoid of (threshold - linear predictor + noise). The
# noise is a single standard-normal draw per row, scaled by 2 and shared
# across both thresholds.
_p = sigmoid(_theta[None, :] - _Xbeta[:, None] + 2 * rng.standard_normal(n)[:, None])
# Label = 1 + number of thresholds whose value in _p is below 0.5.
y = 1 + (_p < 0.5).sum(axis=1)
# some definitions for Hail
# Wrap the NumPy labels and features as Hail ndarray expressions.
hl_y = hl.nd.array(y)
hl_X = hl.nd.array(X)
def hl_gte(nd_left, nd_right):
    """Return the elementwise ``nd_left >= nd_right`` indicator as floats.

    The comparison is evaluated on the difference ``nd_left - nd_right``, so
    the two operands must be valid operands of ``-``. Every entry of the
    result is ``hl.float`` applied to the boolean ``difference >= 0``.
    """
    difference = nd_left - nd_right
    return hl.map(lambda entry: hl.float(entry >= 0), difference)
def hl_scalar_sigmoid(x):
    """Build the Hail expression for the logistic function of a scalar ``x``.

    Two algebraically equal forms are selected by the sign of ``x`` so that
    ``hl.exp`` is only ever applied to a non-positive argument.
    """
    # Used when x > 0: the exponent -x is negative.
    for_positive = 1 / (1 + hl.exp(-x))
    # Used otherwise: bind exp(x) once and reuse it in numerator and denominator.
    for_non_positive = hl.rbind(
        hl.exp(x),
        lambda exp_x: exp_x / (1 + exp_x),
    )
    return hl.if_else(x > 0, for_positive, for_non_positive)
def hl_sigmoid(hl_nd):
    """Apply ``hl_scalar_sigmoid`` to every element of ``hl_nd`` via ``map``."""
    squashed = hl_nd.map(lambda entry: hl_scalar_sigmoid(entry))
    return squashed
def hl_log(nd):
    """Apply ``hl.log`` to every element of ``nd`` via ``map``."""
    logged = nd.map(lambda entry: hl.log(entry))
    return logged
def hl_logit(nd):
    """Return the elementwise logit ``log(nd / (1 - nd))`` using ``hl_log``."""
    odds = nd / (1 - nd)
    return hl_log(odds)
def hl_mean(nd, dim):
    """Return the mean of ``nd`` along axis ``dim``.

    Computed as the sum over that axis divided by ``nd.shape[dim]``.
    """
    total = nd.sum(axis=dim)
    count = nd.shape[dim]
    return total / count
def hl_diff(nd):
    """Return consecutive differences ``nd[i + 1] - nd[i]`` along the first axis."""
    shifted = nd[1:]
    leading = nd[:-1]
    return shifted - leading
def hl_get_t(k):
    """Return the lower-triangular (diagonal included) indicator matrix.

    The levels ``1 .. k - 1`` are compared against themselves with
    ``hl_gte``: entry ``[a, b]`` is the indicator of ``level[a] >= level[b]``,
    which is the pattern produced by ``np.tril`` on a matrix of ones.
    """
    levels = hl.nd.arange(1, k)
    as_column = levels.reshape(-1, 1)
    as_row = levels.reshape(1, -1)
    return hl_gte(as_column, as_row)
def hl_get_s(y, k):
    """Return the +1 / -1 sign matrix comparing levels ``1 .. k - 1`` to ``y``.

    Entry ``[i, j]`` is ``+1`` where level ``j + 1`` is greater than or equal
    to ``y[i]`` and ``-1`` otherwise (the 0/1 indicator from ``hl_gte`` is
    mapped through ``2 * indicator - 1``).
    """
    level_row = hl.nd.arange(1, k).reshape(1, -1)
    label_column = y.reshape(-1, 1)
    at_or_above = hl_gte(level_row, label_column)
    return 2 * at_or_above - 1
def hl_maximum(nd_left, nd_right):
    """Return the elementwise maximum of two Hail ndarrays.

    The underlying data arrays of both inputs are zipped, ``hl.max`` is taken
    over each pair, and the result is wrapped back into an ndarray.
    NOTE(review): the output is rebuilt from ``_data_array()`` alone, so this
    presumably targets 1-D inputs of equal length -- confirm before using it
    on higher-rank arrays.
    """
    paired = hl.zip(nd_left._data_array(), nd_right._data_array())
    larger = paired.map(lambda pair: hl.max(*pair))
    return hl.nd.array(larger)
def hl_get_gamma_0(y, k, m):
    """Build the starting parameter vector from the labels ``y``.

    For each level ``1 .. k - 1`` the fraction of labels at or below that
    level is computed and passed through the logit. The returned vector is
    the concatenation of:

    * the first of those logits,
    * the consecutive differences of the logits,
    * ``m`` zeros.
    """
    levels = hl.nd.arange(1, k)
    # Row j holds the indicator of `level j + 1 >= y[i]` for every label i.
    at_or_below = hl_gte(levels[:, None], y[None, :])
    cumulative_share = hl_mean(at_or_below, 1)
    theta = hl_logit(cumulative_share)
    eta = hl_diff(theta)
    zero_block = hl.nd.zeros(m)
    return hl.nd.hstack([theta[:1], eta, zero_block])


# Alternative name bound to the same function.
hl_get_eta_beta_0 = hl_get_gamma_0
def hl_get_gamma_lower_bound(k, m):
    """Return the elementwise lower bound for the parameter vector.

    Layout: one ``-inf`` entry, then ``k - 2`` zeros, then ``m`` entries of
    ``-inf`` -- so only the middle ``k - 2`` entries are constrained (to be
    non-negative).
    """
    free_head = hl.nd.array([-np.inf])
    non_negative_middle = hl.nd.array([0] * (k-2), dtype=hl.dtype('float'))
    free_tail = hl.nd.array([-np.inf] * m)
    return hl.nd.hstack([free_head, non_negative_middle, free_tail])


# Alternative name bound to the same function.
hl_get_eta_beta_lower_bound = hl_get_gamma_lower_bound
# ordinal reg in hail
def hl_get_Xdot(X, k):
    """Assemble the 3-D design tensor from ``hl_get_t(k)`` and ``-X``.

    Conceptually ``concat([t[None, :, :], -X[:, None, :]], axis=2)`` with
    both pieces tiled to a common leading shape: ``t`` is repeated once per
    row of ``X`` along axis 0, and ``X`` is repeated once per row of ``t``
    along axis 1.
    """
    t = hl_get_t(k)
    n_rows = hl.int(X.shape[0])
    n_levels = hl.int(t.shape[0])
    # One copy of t for each row of X, stacked along a new leading axis.
    t_per_row = hl.nd.concatenate(
        hl.range(n_rows).map(lambda _: t[None, :, :]),
        axis=0,
    )
    # One copy of X for each row of t, stacked along a new middle axis.
    X_per_level = hl.nd.concatenate(
        hl.range(n_levels).map(lambda _: X[:, None, :]),
        axis=1,
    )
    return hl.nd.concatenate([t_per_row, -X_per_level], axis=2)
def hl_get_hess_negjac(Xdot, s, p):
    """Return ``(hess, negjac)`` summed over the first two axes of ``Xdot``.

    * ``negjac`` sums ``(1 - p) * s * Xdot[i, j, :]`` over ``i, j``.
    * ``hess`` sums ``p * (1 - p)`` times the outer product of
      ``Xdot[i, j, :]`` with itself over ``i, j``.

    ``s`` and ``p`` are indexed by the first two axes of ``Xdot``.
    """
    one_minus_p = 1 - p
    weighted_rows = (one_minus_p * s)[:, :, None] * Xdot
    negjac = weighted_rows.sum(axis=(0, 1))
    curvature = p * one_minus_p
    # Broadcast Xdot against itself to form one outer product per (i, j).
    weighted_outer = curvature[:, :, None, None] * Xdot[:, :, :, None] * Xdot[:, :, None, :]
    hess = weighted_outer.sum(axis=(0, 1))
    return hess, negjac
def _hl_get_gamma(_, gamma, gamma_lower_bound, Xdot, s, tol):
    """Loop body: take one Newton-type step on ``gamma`` and recurse.

    Args:
        _: The recursion callable supplied by ``hl.experimental.loop``;
            calling it schedules the next iteration.
        gamma: Current parameter vector.
        gamma_lower_bound: Elementwise lower bound applied after the step.
        Xdot: Design tensor, multiplied with ``gamma`` via ``@``.
        s: Sign matrix multiplying ``Xdot @ gamma``.
        tol: Threshold on the largest absolute step entry.

    Returns:
        The clipped, updated ``gamma`` once the largest absolute entry of the
        step is below ``tol``; otherwise the result of the next iteration.
    """
    p = hl_sigmoid(s * (Xdot @ gamma))
    hess, negjac = hl_get_hess_negjac(Xdot, s, p)
    delta_gamma = hl.nd.solve(hess, negjac)
    # Apply the step, then clip to the lower bound.
    updated = hl_maximum(gamma + delta_gamma, gamma_lower_bound)
    # Convergence is judged on the raw (unclipped) step.
    largest_step = hl.max(hl.abs(delta_gamma)._data_array())
    return hl.if_else(
        largest_step < tol,
        updated,
        _(updated, gamma_lower_bound, Xdot, s, tol),
    )
def hl_get_gamma(X, y, tol, k, m):
    """Fit the parameter vector by iterating ``_hl_get_gamma`` in a Hail loop.

    Args:
        X: Feature ndarray.
        y: Label ndarray.
        tol: Convergence threshold forwarded to ``_hl_get_gamma``.
        k: Level count used to build the initial guess, bounds, design
            tensor and sign matrix.
        m: Number of trailing entries of the parameter vector (zeros in the
            initial guess, ``-inf`` in the lower bound).

    Returns:
        The ``hl.experimental.loop`` expression, declared as a 1-D float64
        ndarray.
    """
    initial_state = (
        hl_get_gamma_0(y, k, m),
        hl_get_gamma_lower_bound(k, m),
        hl_get_Xdot(X, k),
        hl_get_s(y, k),
        tol,
    )
    result_type = hl.dtype('ndarray<float64, 1>')
    return hl.experimental.loop(_hl_get_gamma, result_type, *initial_state)
## looping
# Full fit through `hl_get_gamma` (which uses hl.experimental.loop); this call
# is commented out in the original script.
# hl_gamma = hl_get_gamma(hl_X, hl_y, 1e-5, k, m)
# print(hl_gamma.take(1)[0])
# first `loop` iteration
# The statements below reproduce by hand the first step of `_hl_get_gamma`:
# build the initial guess, bounds, design tensor and sign matrix, ...
hl_gamma_0 = hl_get_gamma_0(hl_y, k, m)
hl_gamma_lower_bound = hl_get_gamma_lower_bound(k, m)
hl_Xdot = hl_get_Xdot(hl_X, k)
hl_s = hl_get_s(hl_y, k)
# ... evaluate the sigmoid at the initial guess, ...
hl_p_0 = hl_sigmoid(
hl_s * (hl_Xdot @ hl_gamma_0)
)
# ... and solve the linear system `hess @ delta = negjac` for the first step.
hl_delta_gamma_0 = hl.nd.solve(
*hl_get_hess_negjac(hl_Xdot, hl_s, hl_p_0)
)
# Evaluate the expression and print the resulting step.
print(hl_delta_gamma_0.take(1)[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment