Last active
July 16, 2021 12:53
-
-
Save olszewskip/3e677477cee1b74f305ef3a032f92ffe to your computer and use it in GitHub Desktop.
Hail_ordinal_reg_dummy_1.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hail as hl
import numpy as np

# Seeded generator so the synthetic data built from it is reproducible.
rng = np.random.default_rng(seed=42)

# Initialise Hail up front.
hl.init()
# dummy data generation | |
def sigmoid(arr):
    """Return the logistic sigmoid ``1 / (1 + exp(-arr))`` element-wise.

    Only ``exp(-|x|)`` is ever evaluated, so the exponential can never
    overflow, however large the magnitude of the input is.

    Parameters
    ----------
    arr : array-like
        Numeric input of any shape. Lists, tuples and scalars are accepted
        as well as ndarrays.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as the input.
    """
    arr = np.asarray(arr)
    # decay = exp(-x) for x > 0 and exp(x) otherwise; always in (0, 1].
    decay = np.exp(-np.abs(arr))
    # x > 0:  1 / (1 + exp(-x));  x <= 0:  exp(x) / (1 + exp(x)).
    numerator = np.where(arr > 0, 1.0, decay)
    return np.asarray(numerator / (1.0 + decay), dtype=float)
# Problem size: k ordinal classes, m features, n observations.
k, m, n = 3, 4, 123

# Parameters used to simulate the data: m slopes and k - 1 thresholds.
_beta = np.array([0.1, 0.2, 1, 2])
_theta = np.array([10, 15])

# Feature matrix of shape (n, m), drawn uniformly between 0 and 8.
X = rng.uniform(0, 8, (n, m))
_Xbeta = X @ _beta

# One noise draw per observation, shared by every threshold of that row.
_noise = 2 * rng.standard_normal(n)
_p = sigmoid(_theta[None, :] - _Xbeta[:, None] + _noise[:, None])

# Label = 1 + number of thresholds whose probability falls below 0.5,
# so every label lies in 1..k.
y = (_p < 0.5).sum(axis=1) + 1

# some definitions for Hail
hl_y = hl.nd.array(y)
hl_X = hl.nd.array(X)
def hl_gte(nd_left, nd_right):
    """Element-wise ``nd_left >= nd_right`` returned as floats.

    The test is carried out on the difference ``nd_left - nd_right``: each
    element of that difference is mapped to ``hl.float(element >= 0)``.
    """
    difference = nd_left - nd_right

    def is_non_negative(value):
        return hl.float(value >= 0)

    return hl.map(is_non_negative, difference)
def hl_scalar_sigmoid(x):
    """Logistic sigmoid of a single Hail numeric expression.

    Uses ``1 / (1 + exp(-x))`` when ``x > 0`` and ``exp(x) / (1 + exp(x))``
    otherwise, so ``exp`` is never taken of a positive argument.
    """
    when_positive = 1 / (1 + hl.exp(-x))
    # Bind exp(x) once and reuse it in numerator and denominator.
    when_not_positive = hl.rbind(hl.exp(x), lambda e: e / (1 + e))
    return hl.if_else(x > 0, when_positive, when_not_positive)
def hl_sigmoid(hl_nd):
    """Apply ``hl_scalar_sigmoid`` to every element of ``hl_nd``."""
    return hl_nd.map(lambda value: hl_scalar_sigmoid(value))
def hl_log(nd):
    """Apply ``hl.log`` to every element of ``nd``."""
    def elementwise_log(value):
        return hl.log(value)

    return nd.map(elementwise_log)
def hl_logit(nd):
    """Element-wise logit: ``log(nd / (1 - nd))``."""
    odds = nd / (1 - nd)
    return hl_log(odds)
def hl_mean(nd, dim):
    """Mean of ``nd`` along axis ``dim``: the axis sum divided by its length."""
    axis_total = nd.sum(axis=dim)
    axis_length = nd.shape[dim]
    return axis_total / axis_length
def hl_diff(nd):
    """Differences between consecutive entries along the first axis."""
    following, preceding = nd[1:], nd[:-1]
    return following - preceding
def hl_get_t(k):
    """Build a lower-triangular 0/1 matrix (np.tril-like) over ``1 .. k-1``.

    Entry ``(r, c)`` is 1.0 when the r-th value of ``arange(1, k)`` is
    greater than or equal to the c-th one, and 0.0 otherwise.
    """
    levels = hl.nd.arange(1, k)
    as_column = levels.reshape(-1, 1)
    as_row = levels.reshape(1, -1)
    return hl_gte(as_column, as_row)
def hl_get_s(y, k):
    """Return a +1 / -1 sign matrix comparing thresholds ``1 .. k-1`` with ``y``.

    Entry ``(i, j)`` is +1 when threshold ``j + 1 >= y[i]`` and -1 otherwise.
    """
    thresholds = hl.nd.arange(1, k).reshape(1, -1)
    labels = y.reshape(-1, 1)
    indicator = hl_gte(thresholds, labels)
    # Map the 0/1 indicator onto -1/+1.
    return 2 * indicator - 1
def hl_maximum(nd_left, nd_right):
    """Pairwise maximum of two Hail ndarrays, rebuilt as a new ndarray.

    NOTE(review): goes through the private ``_data_array()`` accessor;
    presumably intended for 1-D inputs of equal length - confirm before
    using it on higher-rank arrays.
    """
    left_values = nd_left._data_array()
    right_values = nd_right._data_array()
    paired = hl.zip(left_values, right_values)
    larger = paired.map(lambda pair: hl.max(*pair))
    return hl.nd.array(larger)
def hl_get_gamma_0(y, k, m):
    """Initial parameter vector ``[theta_1, eta..., beta...]``.

    ``theta`` is the logit of the fraction of labels at or below each
    threshold ``1 .. k-1``; ``eta`` holds the consecutive differences of
    ``theta``; the trailing ``m`` entries start at zero.
    """
    thresholds = hl.nd.arange(1, k)[:, None]
    labels = y[None, :]
    # at_or_below[j, i] is 1.0 when threshold j + 1 >= y[i].
    at_or_below = hl_gte(thresholds, labels)
    theta = hl_logit(hl_mean(at_or_below, 1))
    eta = hl_diff(theta)
    pieces = [theta[:1], eta, hl.nd.zeros(m)]
    return hl.nd.hstack(pieces)


# Alternative name for the same function.
hl_get_eta_beta_0 = hl_get_gamma_0
def hl_get_gamma_lower_bound(k, m):
    """Lower bounds matching the layout ``[first, k - 2 middle, m trailing]``.

    The first entry and the last ``m`` entries are unbounded (``-inf``);
    the ``k - 2`` entries in between are bounded below by zero.
    """
    first_unbounded = hl.nd.array([-np.inf])
    middle_non_negative = hl.nd.array([0] * (k - 2), dtype=hl.dtype('float'))
    trailing_unbounded = hl.nd.array([-np.inf] * m)
    return hl.nd.hstack([
        first_unbounded,
        middle_non_negative,
        trailing_unbounded,
    ])


# Alternative name for the same function.
hl_get_eta_beta_lower_bound = hl_get_gamma_lower_bound
# ordinal reg in hail | |
def hl_get_Xdot(X, k):
    """Build the 3-D design tensor ``concat([t, -X], axis=2)``.

    ``t`` (from ``hl_get_t``) is repeated once per row of ``X`` along a new
    leading axis, ``X`` is repeated once per row of ``t`` along a new middle
    axis, and the two are joined on the last axis with ``X`` negated.
    Presumably the result has shape ``(n, k - 1, k - 1 + m)`` for an
    ``(n, m)`` input - confirm against callers.
    """
    t = hl_get_t(k)
    n_rows = hl.int(X.shape[0])
    n_thresholds = hl.int(t.shape[0])

    # One copy of t per observation, stacked along axis 0.
    t_tiled = hl.nd.concatenate(
        hl.range(n_rows).map(lambda _row: t[None, :, :]),
        axis=0,
    )
    # One copy of X per threshold, stacked along axis 1.
    X_tiled = hl.nd.concatenate(
        hl.range(n_thresholds).map(lambda _col: X[:, None, :]),
        axis=1,
    )
    return hl.nd.concatenate([t_tiled, -X_tiled], axis=2)
def hl_get_hess_negjac(Xdot, s, p):
    """Return ``(hess, negjac)`` summed over the first two axes of ``Xdot``.

    ``negjac`` sums ``(1 - p) * s * Xdot``; ``hess`` sums
    ``p * (1 - p)`` times the outer product of ``Xdot``'s last axis with
    itself.
    """
    gradient_weights = (1 - p) * s
    gradient_terms = gradient_weights[:, :, None] * Xdot
    negjac = gradient_terms.sum(axis=(0, 1))

    curvature_weights = p * (1 - p)
    # Same left-to-right multiplication order as weights * Xdot_a * Xdot_b.
    hessian_terms = (
        curvature_weights[:, :, None, None]
        * Xdot[:, :, :, None]
        * Xdot[:, :, None, :]
    )
    hess = hessian_terms.sum(axis=(0, 1))
    return hess, negjac
def _hl_get_gamma(_, gamma, gamma_lower_bound, Xdot, s, tol):
    """One update step, written as a body for ``hl.experimental.loop``.

    ``_`` is the recursion callback: it is invoked with the next loop state
    when the step has not converged. The step solves ``hess @ delta =
    negjac``, adds ``delta`` to ``gamma`` and clips the result from below
    with ``gamma_lower_bound``. Convergence is judged on the largest
    absolute entry of the unclipped ``delta``.
    """
    linear_predictor = Xdot @ gamma
    p = hl_sigmoid(s * linear_predictor)
    hess, negjac = hl_get_hess_negjac(Xdot, s, p)
    delta_gamma = hl.nd.solve(hess, negjac)

    updated_gamma = hl_maximum(gamma + delta_gamma, gamma_lower_bound)

    largest_step = hl.max(hl.abs(delta_gamma)._data_array())
    return hl.if_else(
        largest_step < tol,
        updated_gamma,
        _(updated_gamma, gamma_lower_bound, Xdot, s, tol),
    )
def hl_get_gamma(X, y, tol, k, m):
    """Run ``_hl_get_gamma`` inside ``hl.experimental.loop`` until it stops.

    The loop state is ``(gamma, gamma_lower_bound, Xdot, s, tol)`` and the
    declared result type is a 1-D float64 ndarray.
    """
    initial_state = (
        hl_get_gamma_0(y, k, m),
        hl_get_gamma_lower_bound(k, m),
        hl_get_Xdot(X, k),
        hl_get_s(y, k),
        tol,
    )
    result_type = hl.dtype('ndarray<float64, 1>')
    return hl.experimental.loop(_hl_get_gamma, result_type, *initial_state)
## looping
# hl_gamma = hl_get_gamma(hl_X, hl_y, 1e-5, k, m)
# print(hl_gamma.take(1)[0])

# first `loop` iteration
# Same inputs that hl_get_gamma would hand to the loop.
hl_gamma_0 = hl_get_gamma_0(hl_y, k, m)
hl_gamma_lower_bound = hl_get_gamma_lower_bound(k, m)
hl_Xdot = hl_get_Xdot(hl_X, k)
hl_s = hl_get_s(hl_y, k)

# Probabilities at the starting point, then the first step from it.
hl_p_0 = hl_sigmoid(hl_s * (hl_Xdot @ hl_gamma_0))
_hl_hess_0, _hl_negjac_0 = hl_get_hess_negjac(hl_Xdot, hl_s, hl_p_0)
hl_delta_gamma_0 = hl.nd.solve(_hl_hess_0, _hl_negjac_0)
print(hl_delta_gamma_0.take(1)[0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment