`vmap` is a very general function, but like `einsum`, I end up trying a bunch of permutations before it works the way I want. More documentation and examples, or more higher-order functions, would be helpful.
Debugging is much more difficult than in autograd. For example, tracking down NaNs is harder, and inspecting variables in JAX doesn't seem to be possible?
It's as fast as advertised.
`jit` is pretty impressive.
`stax` is a neat little sublibrary. I'd like to see more development there, but I understand the possible scope creep.
I love the idea of riding the upgrade train of JAX, XLA, and GPUs.
I see a lot¹ of examples using internal JAX APIs, while my code doesn't use them, and that gives me pause. Am I missing something, or are more higher-order functions needed?
Is `vectorize` the right API? I'm not sure. Perhaps some common patterns could be extracted into functions. I had a lot of trouble trying to duplicate `vmap` primitives; it was much easier in autograd. After reading h
```python
import numpy as np
import pandas as pd

# test previous algorithm
actuals = pd.read_csv("https://gist.githubusercontent.com/csaid/a57c4ebaa1c7b0671cdc9692638ea4c4/raw/ad1709938834d7bc88b62ff0763733502eb6a329/shower_problem_tau_samples.csv")
DELTA = 0.1  # length of one time step

def survival_function(t, lambda_=50., rho=1.5):
    # Assume simple Weibull model: S(t) = exp(-(t / lambda_) ** rho)
    return np.exp(-(t / lambda_) ** rho)
```
I consider this a greedy algorithm, since at each time step I ask which "twist" is better. I don't think it's optimal.

The idea is to estimate the probability of discovering tau in the next time step, given your current position and knowledge (position being left or right, denoted 1 and 2 here). We calculate that probability as follows.

Let t1 be the maximum time observed in position 1, and t2 the maximum time observed in position 2. Let P be the random variable for which position tau is in (1 or 2), and let lowercase p be our current position. Suppose we start in position 1, i.e. p = 1. Then

$$
\Pr(\text{discover } \tau \text{ in next } \Delta t \mid t_1, t_2, p=1) = \Pr(\text{discover } \tau \text{ in next } \Delta t \mid t_1, t_2, P=1, p=1)\,\Pr(P=1) + \cdots
$$
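The formula above is cut off after its first term. What follows is only a minimal sketch of one way the greedy step could be computed, not the author's code.

It rests on several assumptions:
- It uses the Weibull `survival_function` from the script above.
- P has a 50/50 prior.
- Searching one position can never reveal tau if tau is in the other position, so the P = 2 term vanishes when p = 1.
- Pr(P = 1) is updated on the observed t1 and t2.

The helper names `prob_position`, `prob_discover_next` and `greedy_step` are hypothetical.

```python
import numpy as np

DELTA = 0.1  # length of one time step

def survival_function(t, lambda_=50., rho=1.5):
    # Weibull survival function S(t) = exp(-(t / lambda_) ** rho), as in the script above
    return np.exp(-(t / lambda_) ** rho)

def prob_position(p, t1, t2):
    # Hypothetical helper: posterior Pr(P = p | t1, t2), assuming a 50/50 prior.
    # If tau is in position 1, "not found after searching up to t1" has probability S(t1);
    # searching position 2 can never find it, so that observation has probability 1.
    s1, s2 = survival_function(t1), survival_function(t2)
    return (s1 if p == 1 else s2) / (s1 + s2)

def prob_discover_next(p, t1, t2, delta=DELTA):
    # Hypothetical helper: Pr(discover tau in next delta | t1, t2, current position p).
    # Assumption: only the P == p term contributes, because if tau is in the
    # other position, searching here for another delta cannot reveal it.
    t = t1 if p == 1 else t2
    # Chance tau falls in (t, t + delta], given it has not been found by t
    hazard = (survival_function(t) - survival_function(t + delta)) / survival_function(t)
    return hazard * prob_position(p, t1, t2)

def greedy_step(t1, t2, delta=DELTA):
    # Hypothetical greedy rule: search whichever position has the higher
    # probability of discovering tau in the next step (ties go to position 1).
    p1 = prob_discover_next(1, t1, t2, delta)
    p2 = prob_discover_next(2, t1, t2, delta)
    return 1 if p1 >= p2 else 2

print(greedy_step(0.0, 0.0))
print(greedy_step(0.0, 30.0))
```

With rho > 1 the Weibull hazard increases over time. This rule can therefore favor the position that has already been searched longer. Whether that is desirable is a modeling question this sketch does not settle.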
```python
import pandas as pd
from zepid.base import create_spline_transform
from lifelines import CoxPHFitter
from lifelines.datasets import load_rossi

rossi = load_rossi()
rossi_with_splines = rossi.copy()

# Cubic (term=3), unrestricted spline basis for age; bp holds the knot points
spline_transform, bp = create_spline_transform(rossi_with_splines['age'], term=3, restricted=False)
rossi_with_splines[['age0', 'age1', 'age2']] = pd.DataFrame(spline_transform(rossi_with_splines['age']))
# Replace the raw age column with its spline terms
rossi_with_splines = rossi_with_splines.drop('age', axis=1)
```