Bayesian inference for metaphase model of chromosome dynamics
## This script performs inference for the model of chromosome dynamics in metaphase
## described in Armond et al., 2015, PLoS Comput Biol. The four discrete hidden states (++,+-,-+,--) are
## marginalised out, as described in the methods of Harrison et al., 2021, bioRxiv.
## The likelihood is evaluated via the forward algorithm, and derivatives of the log-likelihood
## are made available via automatic differentiation in Julia. These are used in an adaptive MCMC algorithm
## with the Barker proposal, described in Livingstone and Zanella, 2022, JRSSB.
##
## Jonathan U. Harrison - 2022-07-07
####################################################
using Plots, Random, Distributions
using MCMCChains, StatsPlots, StatsBase
using LinearAlgebra, ForwardDiff
#see https://nextjournal.com/jbowles/the-mathematical-ideal-and-softmax-in-julia
#abstract exponentiation function, subtract max for numerical stability
_exp(x::AbstractVecOrMat) = exp.(x .- maximum(x))
#softmax expects the stabilized exponentiated input
_sftmax(e::AbstractVecOrMat, d::Integer) = (e ./ sum(e, dims = d))
# top level softmax function
function softmax(X::AbstractVecOrMat{T}, dim::Integer)::AbstractVecOrMat where T <: AbstractFloat
    _sftmax(_exp(X), dim)
end
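#quick sanity check (added for illustration, not part of the original analysis):
#softmax should return a non-negative vector summing to one
@assert sum(softmax([1.0, 2.0, 3.0], 1)) ≈ 1.0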
#expand a skeleton of switch times into a +/-1 direction sequence for both sisters;
#skeleton[j] lists the change points of sister j, starting at 1 and ending at T+1
function compute_sequence_given_skeleton(skeleton,sigma0,T)
    state_sequence = repeat([0.0 0.0],outer=T+1)
    state_sequence[1,:] = sigma0
    for j in 1:2
        if length(skeleton[j])>1
            for i in 2:length(skeleton[j])
                mask = skeleton[j][i-1]:skeleton[j][i]
                state_sequence[mask,j] = repeat([state_sequence[mask[1],j]],outer=length(mask))
                state_sequence[mask[end],j] *= -1
            end
        end
    end
    return state_sequence
end
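#small illustrative example (added): with T=5 and one interior switch per sister,
#each sister holds its initial direction until its switch point; the final row
#(index T+1) flips again because T+1 closes the skeleton by convention
example_seq = compute_sequence_given_skeleton([[1,3,6],[1,4,6]],[1.0,-1.0],5)
@assert example_seq[:,1] == [1.0,1.0,-1.0,-1.0,-1.0,1.0]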
#joint log prior density over the eight model parameters
function logpriorjoint(th)
    # th = [tau, alpha, kappa, v_plus, v_minus, L, p_coh, p_icoh]
    tau = Gamma(0.5,1)
    alpha = truncated(Normal(0.01,0.1),0,Inf)
    kappa = truncated(Normal(0.05,0.1),0,Inf)
    v_plus = truncated(Normal(0.03,0.1),0,Inf)
    v_minus = truncated(Normal(-0.03,0.1),-Inf,0)
    L = truncated(Normal(0.790,0.119),0,Inf)
    p_coh = Beta(45,5)
    p_icoh = Beta(12,3)
    theta_prior = [tau,alpha,kappa,v_plus,v_minus,L,p_coh,p_icoh]
    return sum(logpdf.(theta_prior,th))
end
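#quick check (added for illustration): the joint log prior is finite at a
#representative parameter vector
@assert isfinite(logpriorjoint([0.5, 0.02, 0.01, 0.03, -0.05, 0.8, 0.9, 0.8]))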
#forward (filtering) pass: returns the filtered state probabilities xi and the
#per-frame normalising constants f; the marginal log-likelihood is sum(log.(f))
function get_xi_and_f(y,theta::Vector{TT},angle,dt,x0) where {TT}
    tau, alpha, kappa, v_plus, v_minus, L, p_coh, p_icoh = theta
    q_coh = 1-p_coh
    q_icoh = 1-p_icoh
    P = Matrix{TT}(undef, 4, 4)
    # P = zeros(eltype(theta),4,4) #transition matrix #see https://stackoverflow.com/questions/68485811/julia-roots-find-zero-with-forwarddiff-dual-type
    P[:,1] = [p_icoh*p_icoh p_coh*q_coh p_coh*q_coh q_icoh*q_icoh]
    P[:,2] = [p_icoh*q_icoh p_coh*p_coh q_coh*q_coh p_icoh*q_icoh]
    P[:,3] = [p_icoh*q_icoh q_coh*q_coh p_coh*p_coh p_icoh*q_icoh]
    P[:,4] = [q_icoh*q_icoh p_coh*q_coh p_coh*q_coh p_icoh*p_icoh]
    tau *= 1000 #rescale precision
    noise = sqrt(dt/tau)
    get_force(state) = (state>0)*v_plus + (state<0)*v_minus
    T = size(y,1)
    f = Vector{TT}(undef,T)
    xi = Matrix{TT}(undef,T,4)
    eta = Matrix{TT}(undef,T,4) #observation densities per frame and state
    possible_states = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
    xi0 = 0.25*ones(4) #assume initially all states equally likely
    for j in 1:4
        eta[1,j] = pdf(MvNormal(x0,noise),y[1,:])
    end
    for t in 2:T
        for j in 1:4
            state = possible_states[j]
            mu = [y[t-1,1] + dt*(-get_force(state[1]) - kappa*(y[t-1,1]-y[t-1,2]-L*cos(angle[t])) -alpha*y[t-1,1]),
                  y[t-1,2] + dt*(get_force(state[2]) - kappa*(y[t-1,2]-y[t-1,1]+L*cos(angle[t])) -alpha*y[t-1,2]) ]
            eta[t,j] = pdf(MvNormal(mu,noise),y[t,:])
        end
    end
    for t in 1:T
        if t == 1
            f[t] = dot(xi0'*P,eta[t,:])
            xi[t,:] = ((xi0'*P) .* eta[t,:]') ./ f[t]
        else
            f[t] = dot(xi[t-1,:]'*P,eta[t,:])
            xi[t,:] = ((xi[t-1,:]'*P) .* eta[t,:]') ./ f[t]
        end
    end
    return xi, f
end
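#illustrative check (added, with a made-up three-frame series): each row of the
#filtered probabilities xi is a valid distribution over the four hidden states
let y_toy = [1.0 0.1; 1.05 0.05; 1.1 0.0], th = [0.5,0.02,0.01,0.03,-0.05,0.8,0.9,0.8]
    xi_toy, f_toy = get_xi_and_f(y_toy, th, zeros(3), 2.0, y_toy[1,:])
    @assert all(isapprox.(sum(xi_toy, dims=2), 1.0))
end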
#marginal log-likelihood with the hidden states summed out via the forward
#algorithm; parameters with the wrong sign are given zero likelihood
function marginalised_loglikelihood(y,theta::Vector{TT},angle,dt,x0) where {TT}
    if all(sign.(theta) .== [1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0])
        _, f = get_xi_and_f(y,theta,angle,dt,x0)
        if any(f .< 0)
            return -Inf
        else
            log_lik = sum(log.(f))
            return log_lik
        end
    else
        return -Inf
    end
end
#backward sampling: given filtered probabilities xi, draw a hidden state path
#from its posterior (the backward pass of forward-filtering backward-sampling)
function backward_sample_states(xi::Array{TT},theta::Vector{TT}) where {TT}
    T = size(xi,1)
    sigma_sim = Vector{Int64}(undef,T)
    conditional_state_probs = Vector{TT}(undef,4)
    P_col = Vector{TT}(undef,4)
    state_probs = xi[T,:]
    tau, alpha, kappa, v_plus, v_minus, L, p_coh, p_icoh = theta
    q_coh = 1-p_coh
    q_icoh = 1-p_icoh
    P = Matrix{TT}(undef, 4, 4)
    # P = zeros(eltype(theta),4,4) #transition matrix #see https://stackoverflow.com/questions/68485811/julia-roots-find-zero-with-forwarddiff-dual-type
    P[:,1] = [p_icoh*p_icoh p_coh*q_coh p_coh*q_coh q_icoh*q_icoh]
    P[:,2] = [p_icoh*q_icoh p_coh*p_coh q_coh*q_coh p_icoh*q_icoh]
    P[:,3] = [p_icoh*q_icoh q_coh*q_coh p_coh*p_coh p_icoh*q_icoh]
    P[:,4] = [q_icoh*q_icoh p_coh*q_coh p_coh*q_coh p_icoh*p_icoh]
    sigma_sim[T] = rand(Categorical(state_probs))
    for t in 1:(T-1)
        frame = T+1-t
        P_col = P[:,sigma_sim[frame]]
        conditional_state_probs = log.(P_col) .+ log.(xi[frame-1,:])
        if (isinf(sum(conditional_state_probs)))
            sigma_sim[frame-1] = rand(Categorical(exp.(conditional_state_probs)./sum(exp.(conditional_state_probs))));
        else
            sigma_sim[frame-1] = rand(Categorical(softmax(conditional_state_probs,1)));
        end
    end
    return sigma_sim
end
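#usage sketch (added): backward_sample_states is not called below, but given a
#posterior draw theta_hat from the MCMC run, a hidden state path could be
#sampled via, e.g.
# xi_hat, _ = get_xi_and_f(y, theta_hat, angle, dt, y[1,:])
# sigma_path = backward_sample_states(xi_hat, theta_hat)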
#Barker proposal (Livingstone and Zanella, 2022): a Gaussian increment whose
#sign is flipped towards the gradient with logistic probability
function rbarker(x,grad,sigma,diag_sd)
    # x: current location (vector)
    # grad: target log-posterior gradient (vector)
    # sigma: proposal stepsize (scalar)
    # diag_sd: standard deviation for each component (vector)
    z = sigma .* diag_sd .* randn(length(grad)) #componentwise N(0,(sigma*diag_sd)^2) increment
    b = 2 .* (rand(length(grad)) .< 1 ./ (1 .+ exp.(-grad.*z))) .- 1
    return(x .+ z .* b)
end
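#illustrative demo (added): for a standard normal target the gradient at x=2 is
#-2, so Barker proposals from there are skewed back towards the mode
let x = [2.0], g = [-2.0]
    props = [rbarker(x, g, 0.5, [1.0])[1] for _ in 1:10_000]
    println("mean Barker proposal from x=2.0: ", mean(props), " (below 2.0)")
end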
## the log proposal ratio for the acceptance rate is computed as:
## log q(y->x)/q(x->y) = sum_i [log sigmoid(grad_y_i*(x_i-y_i)) - log sigmoid(grad_x_i*(y_i-x_i))]
function log_q_ratio_barker(x,y,grad_x,grad_y)
    # x: current location (vector)
    # y: proposed location (vector)
    # grad_x: target log-posterior gradient at x (vector)
    # grad_y: target log-posterior gradient at y (vector)
    beta1 = -grad_y .* (x .- y)
    beta2 = -grad_x .* (y .- x)
    A = sum(# numerically stable via -log sigmoid(a) = max(-a,0) + log1p(exp(-|a|))
        -(max.(beta1,0)+log1p.(exp.(-abs.(beta1))))+
        (max.(beta2,0)+log1p.(exp.(-abs.(beta2)))))
    return A
end
#step-size sequence for adaptation; kappa in (0.5,1] gives diminishing adaptation
function gamma(t,kappa=0.6)
    return t^(-kappa)
end
###############################################################
###########
#simulate some data
#fix dynamic parameters
#can we reconstruct the skeleton?
# Set a seed for reproducibility.
Random.seed!(14);
dt=2
# T = 50
# skeleton = [[1,15,36,40,T+1],[1,13,33,41,T+1]]
T = 250
skeleton = [[1,15,36,40,58,61,69,84,120,160,224,246,T+1],[1,13,33,41,59,65,76,110,119,165,180,186,220,244,T+1]]
sigma0 = [1.0,-1.0]
state_sequence = compute_sequence_given_skeleton(skeleton,sigma0,T)
tau = 500
alpha = 0.02
kappa = 0.01
v_plus = 0.03
v_minus = -0.05
L = 0.8
theta_d = [tau/1000, alpha, kappa, v_plus, v_minus, L] #NB rescaling tau
get_force(state) = (state>0)*v_plus + (state<0)*v_minus
#generate data
u0 = [1.1 0.1]
angle = repeat([0.0],outer=T)
y = repeat([0.0 0.0],outer=T)
y[1,:] = u0
for t in 2:T
    y[t,1] = y[t-1,1] + dt*(-get_force(state_sequence[t,1]) - kappa*(y[t-1,1]-y[t-1,2]-L*cos(angle[t])) -alpha*y[t-1,1]) + sqrt(dt/tau)*randn()
    y[t,2] = y[t-1,2] + dt*(get_force(state_sequence[t,2]) - kappa*(y[t-1,2]-y[t-1,1]+L*cos(angle[t])) -alpha*y[t-1,2]) + sqrt(dt/tau)*randn()
    # y[t,:] .+= sqrt(dt/tau)*randn(2)
end
p1 = plot(state_sequence[:,1])
plot!(p1,state_sequence[:,2])
p2 = plot(y)
plot(p1,p2)
#############
grad_ll(th) = ForwardDiff.gradient(x -> marginalised_loglikelihood(y, x, angle, dt, y[1,:]), th)
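#optional gradient check (added for illustration): a central finite difference
#should agree closely with the ForwardDiff gradient at a feasible point
let th = [0.5, 0.02, 0.01, 0.03, -0.05, 0.8, 0.9, 0.8], h = 1e-6
    e1 = [1.0; zeros(7)]
    fd = (marginalised_loglikelihood(y,th .+ h*e1,angle,dt,y[1,:]) -
          marginalised_loglikelihood(y,th .- h*e1,angle,dt,y[1,:]))/(2h)
    println("AD gradient: ", grad_ll(th)[1], " finite difference: ", fd)
end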
#re-declare the priors at the top level to draw an initial parameter value
#(NB this re-binds the simulation parameter names to Distribution objects)
tau = Gamma(0.5,1)
alpha = truncated(Normal(0.01,0.1),0,Inf)
kappa = truncated(Normal(0.05,0.1),0,Inf)
v_plus = truncated(Normal(0.03,0.1),0,Inf)
v_minus = truncated(Normal(-0.03,0.1),-Inf,0)
L = truncated(Normal(0.790,0.119),0,Inf)
p_coh = Beta(45,5)
p_icoh = Beta(12,3)
theta_prior = [tau,alpha,kappa,v_plus,v_minus,L,p_coh,p_icoh]
@time begin
    Random.seed!(123);
    x0 = y[1,:]
    # theta_h = [0.95,0.7]
    # theta_d = [0.5,0.02,0.01,0.03,-0.05,0.8]
    # theta = vcat(theta_d, theta_h)
    theta = rand.(theta_prior) #initialise from the prior
    # theta_d = theta[1:6]
    # theta_h = theta[7:8]
    theta_store = []
    num_iter = 5000
    target_ap = 0.4 #target acceptance probability
    print_frequency = num_iter/20
    num_accepted = 0
    iter = 0
    theta_curr = deepcopy(theta)
    #initialise for adaptation
    sigma_vec = Vector{Float64}(undef,num_iter)
    sigma_t = Matrix{Float64}(undef,num_iter,length(theta_curr))
    ## params for adaptation
    sigma = 0.001*2.4/sqrt(length(theta_curr)^(1/3)) #global stepsize, scaled with dimension
    diag_var = ones(length(theta_curr))
    x_means = deepcopy(theta_curr)
    sigma_vec[1] = sigma
    sigma_t[1,:] = diag_var
    while iter < num_iter
        grad_curr = grad_ll(theta_curr)
        theta_star = rbarker(theta_curr,grad_curr,sigma,sqrt.(diag_var)) #RW_proposal(theta_curr,0.01,I(length(theta)))
        grad_star = grad_ll(theta_star)
        #now evaluate acceptance ratio
        if mod(iter,print_frequency) == 0
            println("Accepted: ", num_accepted, " Proposed: ", iter)
            println("proposed: ",marginalised_loglikelihood(y,theta_star,angle,dt,x0), " current: ",marginalised_loglikelihood(y,theta_curr,angle,dt,x0),"\n")
        end
        log_ratio = marginalised_loglikelihood(y,theta_star,angle,dt,x0) -
            marginalised_loglikelihood(y,theta_curr,angle,dt,x0) +
            logpriorjoint(theta_star) - logpriorjoint(theta_curr) +
            log_q_ratio_barker(theta_curr,theta_star,grad_curr,grad_star)
        u = rand()
        if log_ratio > log(u)
            #accept
            global num_accepted += 1
            global theta_curr = copy(theta_star)
            append!(theta_store,[theta_star])
        else
            append!(theta_store,[theta_curr])
        end #o/w reject
        global iter += 1
        #adaptation
        # 1- adapt global scale
        ap = min(1,exp(log_ratio)) #acceptance probability
        log_sigma_2 = log(sigma^2)+gamma(1+iter)*(ap-target_ap)
        global sigma = sqrt(exp(log_sigma_2))
        # 2- adapt means
        x_means .= x_means .+ gamma(1+iter) .* (theta_curr .- x_means)
        # 3- adapt diagonal covariance
        diag_var .= diag_var .+ gamma(1+iter) .* ((theta_curr .- x_means).^2 .- diag_var)
        # store adaptation parameters
        sigma_vec[iter] = sigma
        sigma_t[iter,:] = diag_var
        if iter > 10^8
            break
        end
    end
end
plot(log.(sigma_t),layout=8)
thinning = 10
p_arr = [plot(1:thinning:num_iter,[theta_store[jj][ii] for jj in 1:thinning:length(theta_store)],legend=false) for ii in 1:8]
plot(p_arr...)
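#optional post-processing (added for illustration): MCMCChains is imported above
#but otherwise unused; the stored samples can be wrapped in a Chains object for
#summary statistics and convergence diagnostics
chn = Chains(permutedims(reduce(hcat, theta_store)),
    ["tau","alpha","kappa","v_plus","v_minus","L","p_coh","p_icoh"])
display(summarystats(chn))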