-
-
Save JohannesNE/ce399b6d7422621b37e54a03711e2bce to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| using Turing | |
| using DataFrames | |
| using CSV | |
| using Tidier | |
| using Plots | |
| using StatsPlots | |
| using BSplines | |
# Download the sample recording and load it into a DataFrame.
sample_data_path = "https://gist.githubusercontent.com/JohannesNE/df539760195619c642b0ea43b07c6ca3/raw/bb72096912b07b5221bd4a2202eac0ec7c9121a4/acc_apnea.csv"
acc = CSV.File(download(sample_data_path)) |> DataFrame

# Restrict to rows with `sec < 10` and pull out the predictor and response columns.
acc10s = @filter(acc, sec < 10)
X10s = @pull(acc10s, sec)   # column `sec` (presumably time in seconds) -> model x
Y10s = @pull(acc10s, z)     # column `z` -> model y
"""
    gq2matrix(gq, selector)

Turn the array of per-draw named tuples produced by `generated_quantities` into a
matrix: the field `selector` of each draw in the first column of `gq` (presumably
the first chain) becomes one row of the result.
"""
function gq2matrix(gq, selector)
    first_column = view(gq, :, 1)
    # Each draw's vector becomes a 1×n row (adjoint) so `vcat` stacks them as rows.
    as_row(draw) = draw[selector]'
    return mapreduce(as_row, vcat, first_column)
end
"""
    plot_post_pred!(pred_draws, x, n = 100)

Draw up to `n` posterior draws (rows of `pred_draws`) against `x` as faint lines,
overlay their pointwise mean as a thick blue line, and return the plot.

Despite the `!`, this starts a fresh plot with `plot()` rather than drawing onto the
current one; add further series afterwards with `plot!`.

# Arguments
- `pred_draws`: one draw per row, one column per element of `x` (e.g. the output of
  `gq2matrix`, which may be a plain `Matrix` or a 1-row `Adjoint`).
- `x`: x-coordinates shared by every draw.
- `n`: maximum number of draws to show. It is clamped to the number of rows, so a
  chain with fewer than `n` draws no longer raises a `BoundsError`.
"""
function plot_post_pred!(pred_draws::AbstractVecOrMat, x, n = 100)
    plot()
    n_shown = min(n, size(pred_draws, 1))
    for i in 1:n_shown
        plot!(x, pred_draws[i, :], legend = false, alpha = 0.2)
    end
    # Pointwise mean across draws: one value per column, i.e. per x position.
    epred_mean = mean.(eachcol(pred_draws))
    plot!(x, epred_mean, lw = 3, color = :blue)
    return current()
end
"""
    get_time_since_event(time_vec, event_vec)

For each time in `time_vec`, return `t - e`, where `e` is the latest event in
`event_vec` strictly earlier than `t`. When no event is strictly earlier, `e` falls
back to `event_vec[1]`. An event exactly equal to `t` is not yet counted, except for
the first event.

Both vectors are assumed sorted ascending, since the scan only moves forward through
`event_vec`. Several events may fall between consecutive times; all of them are
skipped, so the result is always measured from the most recent one.

The element type of the result is the promotion of both inputs' element types. This
means integer event positions with floating-point times work, as do AD dual numbers
in either argument.

Throws `ArgumentError` if `event_vec` is empty.
"""
function get_time_since_event(time_vec, event_vec)
    isempty(event_vec) &&
        throw(ArgumentError("event_vec must contain at least one event"))
    T = promote_type(eltype(time_vec), eltype(event_vec))
    time_since = zeros(T, length(time_vec))
    i_event = firstindex(event_vec)
    last_event = lastindex(event_vec)
    for (k, t) in enumerate(time_vec)
        # `while`, not `if`: several events can lie between two consecutive times.
        while i_event < last_event && t > event_vec[i_event + 1]
            i_event += 1
        end
        time_since[k] = t - event_vec[i_event]
    end
    return time_since
end
# ---- Spline setup ----
# Breakpoint counts for the two spline components used by the model:
# the beat spline (evaluated on time since the last beat) and the trend spline
# (evaluated on the time column itself).
num_knots_beat, num_knots_trend = 30, 20

# Evenly spaced breakpoints. Quantile-spaced breakpoints, e.g.
# `quantile(x, range(0, 1; length = n))`, are an alternative that was tried.
knots_list_beat = range(0, 1; length = num_knots_beat)
# NOTE(review): starts at 0, which assumes X10s begins at 0 — confirm.
knots_list_trend = range(0, last(X10s); length = num_knots_trend)

# Order-3 B-spline bases over the breakpoints above.
# NOTE(review): the beat basis only spans [0, 1]; time since the last beat can exceed
# 1 when a sampled beat interval is longer than that — confirm how BSplines.jl
# evaluates a Spline outside its basis domain.
basis_beat = BSplineBasis(3, knots_list_beat)
basis_trend = BSplineBasis(3, knots_list_trend)
# Bayesian regression of `y` on `x`. The mean is the sum of three terms:
#   - an intercept;
#   - a "beat" spline evaluated on the time since the most recent latent beat;
#   - a "trend" spline evaluated on `x`.
# The likelihood is independent Normal noise.
#
# Arguments:
# - `x`: sorted time points; `last(x) - first(x)` is used as the observation length.
# - `y`: observations, one per element of `x`. Entries may be `missing`; Turing then
#   samples them, which the posterior-predictive run relies on.
#
# Returns `(mu = mu, pred = y)`: the noise-free mean at each `x`, and `y` with any
# missing entries filled by their sampled values (read back via `generated_quantities`).
#
# Reads the global spline bases `basis_beat` and `basis_trend`.
@model function full_spline_regression(x, y)
    obs_len_s = last(x) - first(x)

    # Beat-interval hyperpriors.
    mean_beat_interval ~ LogNormal(log(1), 0.4)
    rel_var_beat_interval ~ Exponential(0.1)
    # NOTE(review): not used below, and it is a discrete (Poisson) parameter, which
    # gradient-based samplers such as NUTS do not handle — confirm or remove.
    n_beats ~ Poisson((1 / mean_beat_interval) * obs_len_s)

    # Noise *variance*; the likelihood takes its square root as the standard deviation.
    σ² ~ truncated(Normal(0, 1); lower = 0)

    intercept ~ Normal(0, sqrt(3))

    # Spline weights with iid standard-normal priors. This is the same prior as the
    # deprecated `MvNormal(zeros(n), 1)` form.
    w_beat ~ filldist(Normal(0, 1), length(basis_beat))
    w_trend ~ filldist(Normal(0, 1), length(basis_trend))
    s_beat = Spline(basis_beat, w_beat)
    s_trend = Spline(basis_trend, w_trend)

    # Latent beats: a fixed number (12) of intervals. This should ideally follow
    # n_beats (e.g. drawn from a Poisson), but there may be a better approach.
    beat_intervals ~ filldist(
        LogNormal(log(mean_beat_interval), rel_var_beat_interval), 12
    )
    beat_pos = [0; cumsum(beat_intervals)]
    time_since_beat = get_time_since_event(x, beat_pos)

    mu = intercept .+ s_beat.(time_since_beat) .+ s_trend.(x)

    # Element-wise likelihood so that `missing` entries of y are sampled individually.
    for i in eachindex(y)
        y[i] ~ Normal(mu[i], sqrt(σ²))
    end

    return (mu = mu, pred = y)
end
# Fit the model to the first 10 s of data (100 NUTS draws).
full_spline_model = full_spline_regression(X10s, Y10s)
chains_full_spline_model = sample(full_spline_model, NUTS(), 100)

# Posterior predictive: rerun the model with every y set to `missing`, so each y[i]
# is sampled for every posterior draw.
# The vector is filled with `missing` explicitly instead of being built with `undef`.
# This guarantees the entries really are `missing`, rather than depending on how
# uninitialized memory happens to be read.
full_spline_model_post = full_spline_regression(
    X10s, Vector{Union{Missing, Float64}}(missing, length(X10s))
)
samples_full_spline_model_post = generated_quantities(
    full_spline_model_post, chains_full_spline_model
)
samples_full_spline_model_post_mu = gq2matrix(samples_full_spline_model_post, :mu)

# Posterior mean curves with the observed signal overlaid in black.
plot_post_pred!(samples_full_spline_model_post_mu, X10s)
plot!(X10s, Y10s, lw = 1, color = :black)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.