- `vmap` is a very general function, but like `einsum`, I end up trying a bunch of permutations before it works the way I want. More documentation and examples, or higher-order functions, would be helpful.
- Debugging is much more difficult than in autograd. Ex: tracking down NaNs is harder, and inspecting variables in jax is not possible?
- It's as fast as advertised. `jit` is pretty impressive.
- `stax` is a neat little sublibrary. I'd like to see more development there, but I understand the possible scope-creep.
- I love the idea of riding the upgrade train of Jax, XLA and GPUs.
- I see a lot¹ of examples using internal Jax APIs and my code doesn't, so that gives me pause. Am I missing something, or are more higher-order functions needed?
- 🆕 Is `vectorize` the right API? I'm not sure. Perhaps some common patterns could be extracted into functions. I had a lot of trouble trying to duplicate `elementwise_grad` in `grad` + `vmap` primitives. It was much easier in autograd. After reading h
"""
I consider this a greedy algorithm, since at each time step, I ask which is a better "twist". I don't think it's optimal.
The idea is to estimate the probability of discovering tau in the next time step, given your current position and knowledge (position being left or right, denoted 1 and 2 here). We calculate the probability of discovering tau in the next time step as follows:
t1 is the max time observed in position 1, and t2 in position 2. Denote P the random variable of which position tau is in (1 or 2). Small p is our current position. Suppose we start in position 1, i.e. p=1
Pr(discover tau in next delta time | t1, t2, p=1) =
Pr(discover tau in next delta time | t1, t2, P=1, p=1) * Pr(P=1) +
MAX = 114
MIN = 22

def falling_factorial(n, k):
    """
    computes (n) * (n-1) * ... * (n-(k-1))
    """
    running_product = n
    for p in range(n-1, n-k, -1):
        running_product *= p
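The snippet above is cut off before the function returns. As a sketch only, assuming the intended behaviour is the one stated in its docstring and that the function returns the running product, a self-contained version can be cross-checked against the standard library's `math.perm` (Python 3.8+), which computes the same quantity for `0 <= k <= n`. The `k == 0` guard is an added assumption, not part of the original.

```python
import math

def falling_factorial(n, k):
    """Compute n * (n-1) * ... * (n-(k-1)), i.e. k descending factors."""
    if k == 0:
        return 1  # empty product (assumed convention, not in the original)
    running_product = n
    for p in range(n - 1, n - k, -1):
        running_product *= p
    return running_product

# Cross-check against math.perm for small inputs.
checks = [(n, k) for n in range(1, 8) for k in range(0, n + 1)]
all_match = all(falling_factorial(n, k) == math.perm(n, k) for n, k in checks)
```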
import pandas as pd
from zepid.base import create_spline_transform
from lifelines import CoxPHFitter
from lifelines.datasets import load_rossi

rossi = load_rossi()
rossi_with_splines = rossi.copy()
spline_transform, bp = create_spline_transform(rossi_with_splines['age'], term=3, restricted=False)
rossi_with_splines[['age0', 'age1', 'age2']] = pd.DataFrame(spline_transform(rossi_with_splines['age']))
rossi_with_splines = rossi_with_splines.drop('age', axis=1)
import numpy as np
from lifelines import WeibullFitter

lambda_, rho_ = 2, 0.5
N = 10_000
T_actual = lambda_ * np.random.exponential(1, size=N)**(1/rho_)
T_censor = lambda_ * np.random.exponential(1, size=N)**(1/rho_)
T = np.minimum(T_actual, T_censor)
E = T_actual < T_censor
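A quick sanity check on the simulation above, as a sketch using only the standard library (the seed, sample size, and evaluation point `t0` are arbitrary choices for illustration). If `X` is a unit exponential, then `lambda_ * X**(1/rho_)` has survival function `exp(-(t/lambda_)**rho_)`, i.e. it is Weibull with scale `lambda_` and shape `rho_`. Because the event and censoring times come from the same distribution, roughly half of the observations should be uncensored.

```python
import math
import random

random.seed(0)  # arbitrary seed so the sketch is reproducible
lambda_, rho_ = 2, 0.5
N = 20_000

def draw_weibull():
    # Unit exponential pushed through the same transform as in the text.
    return lambda_ * random.expovariate(1.0) ** (1 / rho_)

T_actual = [draw_weibull() for _ in range(N)]
T_censor = [draw_weibull() for _ in range(N)]
T = [min(a, c) for a, c in zip(T_actual, T_censor)]
E = [a < c for a, c in zip(T_actual, T_censor)]

# Share of subjects whose event is observed before censoring.
observed_fraction = sum(E) / N

# Compare the empirical survival at t0 with the Weibull formula.
t0 = 1.5
empirical_sf = sum(a > t0 for a in T_actual) / N
theoretical_sf = math.exp(-((t0 / lambda_) ** rho_))
```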
One of the reasons I'm really excited about autograd is that it lets me transform my abstract parameters into business logic. Let me explain with an example. Suppose I am modeling customer churn, and I have fitted a Weibull survival model using maximum likelihood estimation. I have two parameter estimates: lambda-hat and rho-hat. I also have their covariance matrix, which tells me how much uncertainty is present in the estimates (in lifelines, this is under the `variance_matrix_` property). From this, I can plot the survival curve, which is fine, but what I really want is a measure of lifetime value. Customers give us $10 each time period for the first 3 time periods, and then $30 a month afterwards. For a single user, the average LTV (up to `timeline`) calculation might look like:
from autograd import numpy as np
from lifelines import WeibullFitter

# create a Weibull model with fake data
wf = WeibullFitter().fit(np.arange(1, 100))
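The LTV calculation described above can be sketched as follows. This is a minimal illustration, not the original code. It assumes the Weibull survival function is parameterised as S(t) = exp(-(t/lambda)^rho), that revenue in period t is weighted by the probability of still being a customer at t, and it uses made-up parameter values in place of fitted estimates. The names `survival`, `revenue`, and `ltv` are hypothetical.

```python
import math

def survival(t, lambda_, rho_):
    # Assumed Weibull parameterisation: S(t) = exp(-(t / lambda)^rho)
    return math.exp(-((t / lambda_) ** rho_))

def revenue(t):
    # $10 per period for the first 3 periods, $30 per period afterwards
    return 10 if t <= 3 else 30

def ltv(params, timeline):
    # Expected revenue: each period's payment weighted by P(still a customer)
    lambda_, rho_ = params
    return sum(revenue(t) * survival(t, lambda_, rho_) for t in timeline)

params = (20.0, 1.2)       # made-up stand-ins for lambda-hat and rho-hat
timeline = range(1, 100)   # periods 1..99
estimate = ltv(params, timeline)
```

Written this way, the LTV is a smooth function of `params`, which is what makes it a candidate for automatic differentiation once `math` is swapped for `autograd.numpy`.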
import numpy as np

def piecewise_exponential_survival_data(n, breakpoints, lambdas):
    """
    Examples
    --------
    >>> T = piecewise_exponential_survival_data(100000, [1, 3], [0.2, 3, 1.])
    >>> NelsonAalenFitter().fit(T).plot()
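The body of the function is not shown above. As a sketch of one way such data could be generated (not the original implementation), the version below uses inverse-transform sampling on the cumulative hazard. It assumes the third argument holds strictly positive hazard rates, one per interval, so there is one more rate than breakpoints. If the original treats `lambdas` as scales, the rates would be their reciprocals. The helper name `invert_cumulative_hazard` and the `seed` argument are hypothetical.

```python
import random

def invert_cumulative_hazard(e, breakpoints, rates):
    """Return t such that the piecewise-linear cumulative hazard H(t) equals e."""
    start, accumulated = 0.0, 0.0
    for end, rate in zip(breakpoints, rates):
        segment = rate * (end - start)   # hazard mass on [start, end)
        if e <= accumulated + segment:
            return start + (e - accumulated) / rate
        accumulated += segment
        start = end
    # Past the last breakpoint, the final rate applies indefinitely.
    return start + (e - accumulated) / rates[-1]

def piecewise_exponential_survival_data(n, breakpoints, rates, seed=None):
    assert len(rates) == len(breakpoints) + 1
    rng = random.Random(seed)
    # If E ~ Exp(1), then T = H^{-1}(E) has cumulative hazard H.
    return [invert_cumulative_hazard(rng.expovariate(1.0), breakpoints, rates)
            for _ in range(n)]

T = piecewise_exponential_survival_data(1000, [1, 3], [0.2, 3, 1.0], seed=0)
```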
for _ in range(10):
    v = np.random.uniform(-10, 10, size=2)
    print(check_grad(_negative_log_likelihood, gradient_function, v, log(T), E), v)
"""
0.4712740782685882 [-6.56058891 5.14902935]
3237.6524450846837 [ 9.08538372 -1.41142257]
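The numbers printed above look like the output of a gradient check, where values far from zero mean the analytic gradient and the function disagree at that point. Assuming `check_grad` here is `scipy.optimize.check_grad`, it returns the 2-norm of the difference between the supplied gradient and a finite-difference approximation. The sketch below reimplements that idea with the standard library only. `finite_difference_check`, the toy function, and both gradients are hypothetical stand-ins for `_negative_log_likelihood` and `gradient_function`.

```python
import math

def finite_difference_check(func, grad, x, eps=1e-6):
    """2-norm of (analytic gradient - forward-difference gradient) at x."""
    f0 = func(x)
    analytic = grad(x)
    squared_error = 0.0
    for i in range(len(x)):
        shifted = list(x)
        shifted[i] += eps
        numeric = (func(shifted) - f0) / eps
        squared_error += (analytic[i] - numeric) ** 2
    return math.sqrt(squared_error)

def toy(x):
    return x[0] ** 2 + 3.0 * x[0] * x[1]

def toy_grad(x):
    return [2.0 * x[0] + 3.0 * x[1], 3.0 * x[0]]

def wrong_grad(x):
    # Deliberately wrong: the cross term is missing from the first component.
    return [2.0 * x[0], 3.0 * x[0]]

point = [1.5, -2.0]
good_error = finite_difference_check(toy, toy_grad, point)
bad_error = finite_difference_check(toy, wrong_grad, point)
```

A correct gradient gives an error near the finite-difference noise floor, while the broken one is off by the size of the missing term.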
N | frac | batch | single | ratio (batch / single)
---|---|---|---|---
432 | 0.01 | 0.17586684226989746 | 0.24918293952941895 | 0.7057740092560965
432 | 0.09909090909090909 | 0.139909029006958 | 0.18905401229858398 | 0.7400479223154045
432 | 0.1881818181818182 | 0.18714380264282227 | 0.19823789596557617 | 0.9440364655369406
432 | 0.2772727272727273 | 0.15966224670410156 | 0.1492321491241455 | 1.0698917601949116
432 | 0.3663636363636364 | 0.1783008575439453 | 0.2960469722747803 | 0.6022721873286135
432 | 0.4554545454545455 | 0.2235269546508789 | 0.2482771873474121 | 0.9003120948768426
432 | 0.5445454545454546 | 0.2755610942840576 | 0.27878808975219727 | 0.9884249163190293
432 | 0.6336363636363637 | 0.2598390579223633 | 0.19139719009399414 | 1.3575907660648399
432 | 0.7227272727272728 | 0.28580784797668457 | 0.19529199600219727 | 1.4634898194878856