Skip to content

Instantly share code, notes, and snippets.

@JohannesNE
Created April 8, 2024 08:25
Show Gist options
  • Select an option

  • Save JohannesNE/ce399b6d7422621b37e54a03711e2bce to your computer and use it in GitHub Desktop.

Select an option

Save JohannesNE/ce399b6d7422621b37e54a03711e2bce to your computer and use it in GitHub Desktop.
using Turing
using DataFrames
using CSV
using Tidier
using Plots
using StatsPlots
using BSplines
# Fetch the example recording, then restrict it to the first 10 seconds.
sample_data_path = "https://gist.githubusercontent.com/JohannesNE/df539760195619c642b0ea43b07c6ca3/raw/bb72096912b07b5221bd4a2202eac0ec7c9121a4/acc_apnea.csv"
acc = sample_data_path |> download |> CSV.File |> DataFrame
acc10s = @filter(acc, sec < 10)
# Predictor: time column `sec`; response: column `z`
# (presumably one accelerometer axis — confirm against the data source).
X10s = @pull(acc10s, sec)
Y10s = @pull(acc10s, z)
"""
    gq2matrix(gq, selector; chain = 1)

Stack one field of the per-draw outputs of `generated_quantities` into a matrix.

`gq` is indexed as `gq[:, chain]`, i.e. it is expected to be a draws × chains array
whose elements support `getindex(element, selector)` (for example the named tuples
returned by a model, selected with a symbol such as `:mu`). Each draw's field is
adjoint-ed into a row, so a vector-valued field of length `m` over `d` draws yields a
`d × m` matrix.

`chain` picks which chain's column is stacked; the default (`1`) reproduces the
previous behaviour of using only the first chain.
"""
function gq2matrix(gq, selector; chain = 1)
    # Materialise the rows first: Base has an optimised `reduce(vcat, ::AbstractVector)`
    # that concatenates in one pass, whereas reducing over a lazy generator falls back
    # to repeated pairwise `vcat` (quadratic in the number of draws).
    rows = [getindex(samp, selector)' for samp in gq[:, chain]]
    return reduce(vcat, rows)
end
"""
    plot_post_pred!(pred_draws, x, n = 100)

Plot up to `n` posterior draws (rows of `pred_draws`) against `x` as faint lines, then
overlay the pointwise mean over *all* draws as a thick blue line.

Despite the `!`, this starts a fresh plot with `plot()` and returns it via `current()`,
so later `plot!` calls draw on top of it.

`pred_draws` is expected to be (draws × length(x)). `n` is capped at the number of
available draws, so passing fewer than `n` rows no longer raises a `BoundsError`.
"""
function plot_post_pred!(pred_draws::AbstractArray, x, n = 100)
    plot()
    n_shown = min(n, size(pred_draws, 1))
    for i in 1:n_shown
        plot!(x, pred_draws[i, :], legend=false, alpha = 0.2)
    end
    # Column-wise mean = mean across draws at each x position.
    epred_mean = mean.(eachcol(pred_draws))
    plot!(x, epred_mean, lw = 3, color = :blue)
    return current()
end
"""
    get_time_since_event(time_vec, event_vec)

For each time point in `time_vec`, return the time elapsed since the most recent event
in `event_vec`. An event at time `e` becomes the reference only once `t > e`. The first
event is the reference for all earlier time points, so those get negative values.

Both vectors are assumed to be sorted ascending, because a single forward pass over
`time_vec` advances one cursor through `event_vec`.

The result's element type is `promote_type` of both inputs' element types. This lets
integer event times mix with floating-point sample times, and lets AD dual numbers
appear in either input, without an `InexactError`.

Throws `ArgumentError` if `event_vec` is empty while `time_vec` is not.
"""
function get_time_since_event(time_vec, event_vec)
    if isempty(event_vec) && !isempty(time_vec)
        throw(ArgumentError("event_vec must contain at least one event"))
    end
    T = promote_type(eltype(time_vec), eltype(event_vec))
    ann_index = zeros(T, length(time_vec))
    n_events = length(event_vec)
    i_ann = 1
    for (k, t) in enumerate(time_vec)
        # `while`, not `if`: several events may fall between two consecutive samples,
        # and all of them must be skipped to reach the most recent one.
        while i_ann < n_events && t > event_vec[i_ann + 1]
            i_ann += 1
        end
        ann_index[k] = t - event_vec[i_ann]
    end
    return ann_index
end
# --- Spline setup ------------------------------------------------------------
# Knot counts: beat-shape spline first, slow-trend spline second.
num_knots_beat, num_knots_trend = 30, 20

# Evenly spaced knots. The beat spline spans 0 to 1 on the time-since-beat axis; the
# trend spline spans 0 to the last sample time in X10s.
knots_list_beat = range(0, 1; length=num_knots_beat)
knots_list_trend = range(0, last(X10s); length=num_knots_trend)

# Both bases pass 3 as BSplineBasis' first argument (BSplines.jl's order parameter;
# NOTE(review): confirm which polynomial degree this was meant to give).
basis_beat, basis_trend =
    BSplineBasis(3, knots_list_beat), BSplineBasis(3, knots_list_trend)
# Spline regression of `y` on `x`: intercept + a beat-shape spline evaluated on the
# time since the most recent (latent) beat + a slow-trend spline evaluated on `x`.
# Pass `y` as a vector of `missing` to draw posterior predictions; the model returns
# `(mu = …, pred = y)` for use with `generated_quantities`.
# Reads the globals `basis_beat` and `basis_trend`, which must be defined beforehand.
@model function full_spline_regression(x, y)
    obs_len_s = last(x) - first(x)

    # Beat-interval hyperpriors: typical interval (LogNormal with median 1) and the
    # log-scale spread of individual intervals around it.
    mean_beat_interval ~ LogNormal(log(1), 0.4)
    rel_var_beat_interval ~ Exponential(0.1)
    # Not currently used.
    # NOTE(review): this is a discrete parameter, while Turing's HMC/NUTS samplers only
    # support continuous ones. Confirm it samples as intended, or drop it, or use a
    # Gibbs composition.
    n_beats ~ Poisson((1 / mean_beat_interval) * obs_len_s)

    # Observation-noise *variance* (half-normal prior).
    σ² ~ truncated(Normal(0, 1); lower=0)
    intercept ~ Normal(0, sqrt(3))

    # Independent standard-normal priors on all spline coefficients. This is equivalent
    # to the deprecated `MvNormal(zeros(n), 1)` form.
    w_beat ~ filldist(Normal(0, 1), length(basis_beat))
    w_trend ~ filldist(Normal(0, 1), length(basis_trend))
    s_beat = Spline(basis_beat, w_beat)
    s_trend = Spline(basis_trend, w_trend)

    # Beat positions: a fixed count of 12 intervals starting at time 0.
    # Ideally the count would be n_beats (possibly Poisson-distributed), but there may
    # be a better way. If 12 intervals end before last(x), the remaining samples are
    # all measured from the 12th interval's end.
    beat_intervals ~ filldist(LogNormal(log(mean_beat_interval),
                                        rel_var_beat_interval), 12)
    beat_pos = [0; cumsum(beat_intervals)]
    time_since_beat = get_time_since_event(x, beat_pos)

    # Linear predictor for every observation.
    mu = intercept .+ s_beat.(time_since_beat) .+ s_trend.(x)

    # `Normal` takes a standard deviation, so convert the variance. Keep the per-element
    # loop so that `missing` entries in `y` are treated as quantities to predict.
    σ = sqrt(σ²)
    for i in eachindex(y)
        y[i] ~ Normal(mu[i], σ)
    end
    return (mu = mu, pred = y) # sampled with generated_quantities
end
# Fit: condition the model on the observed 10 s window and draw 100 NUTS samples
# (a single chain).
full_spline_model = full_spline_regression(X10s, Y10s)
chains_full_spline_model = sample(full_spline_model, NUTS(), 100)

# Predict: same x, but every y explicitly `missing`, so that `generated_quantities`
# re-runs the model for each posterior draw and returns its `(mu, pred)` tuple.
# The explicit `missing` constructor states the intent rather than relying on what
# `undef` happens to initialise a Union{Missing, Float64} array to.
full_spline_model_post = full_spline_regression(
    X10s, Vector{Union{Missing, Float64}}(missing, length(X10s)))
samples_full_spline_model_post = generated_quantities(full_spline_model_post, chains_full_spline_model)
samples_full_spline_model_post_mu = gq2matrix(samples_full_spline_model_post, :mu)

# Posterior mean curves with the raw signal overlaid in black.
plot_post_pred!(samples_full_spline_model_post_mu, X10s)
plot!(X10s, Y10s, lw = 1, color = :black)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment