## Bayesian inference for metaphase model of chromosome dynamics
## This script performs inference for the model of chromosome dynamics in metaphase
## described in Armond et al., 2015, PLoS Comput Biol. The four discrete hidden states (++,+-,-+,--) are
## marginalised out, as described in the methods of Harrison et al., 2021, bioRxiv.
## The likelihood is evaluated via the forward algorithm, and derivatives of the log likelihood
## are made available via automatic differentiation in Julia. These are used in an adaptive MCMC algorithm
## with the Barker proposal, described in Livingstone and Zanella, 2022, JRSSB.
##
## Jonathan U. Harrison - 2022-07-07
####################################################
using Plots, Random, Distributions
using MCMCChains, StatsPlots, StatsBase
using LinearAlgebra, ForwardDiff
#see https://nextjournal.com/jbowles/the-mathematical-ideal-and-softmax-in-julia
#abstract exponentiation function, subtract max for numerical stability
_exp(x::AbstractVecOrMat) = exp.(x .- maximum(x))
#softmax expects the stabilized exponentiated values
_sftmax(e::AbstractVecOrMat, d::Integer) = (e ./ sum(e, dims = d))
# top level softmax function
function softmax(X::AbstractVecOrMat{T}, dim::Integer)::AbstractVecOrMat where T <: AbstractFloat
    _sftmax(_exp(X), dim)
end
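#a quick illustrative check of the softmax (values approximate; not part of the original analysis):
#softmax([1.0, 2.0, 3.0], 1) ≈ [0.0900, 0.2447, 0.6652], and because the max is subtracted first,
#softmax([1.0, 2.0, 3.0] .+ 1000.0, 1) returns the same probabilities without overflowing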
function compute_sequence_given_skeleton(skeleton,sigma0,T)
    #build the (T+1)x2 sequence of hidden +/-1 states implied by the switch times in skeleton
    state_sequence = repeat([0.0 0.0],outer=T+1)
    state_sequence[1,:] = sigma0
    for j in 1:2 #loop over the two sister kinetochores
        if length(skeleton[j])>1
            for i in 2:length(skeleton[j])
                mask = skeleton[j][i-1]:skeleton[j][i]
                state_sequence[mask,j] = repeat([state_sequence[mask[1],j]],outer=length(mask))
                state_sequence[mask[end],j] *= -1 #flip the state at each switch time
            end
        end
    end
    return state_sequence
end
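#a small illustrative example (values chosen here for exposition, not from the paper):
#with T = 5, sigma0 = [1.0,-1.0] and skeleton [[1,3,6],[1,4,6]], sister 1 holds its initial
#state and flips sign at frame 3, sister 2 at frame 4, giving a 6x2 matrix of +/-1 states:
#compute_sequence_given_skeleton([[1,3,6],[1,4,6]], [1.0,-1.0], 5)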
function logpriorjoint(th)
    # tau, alpha, kappa, v_plus, v_minus, L, p_coh, p_icoh = th
    tau = Gamma(0.5,1)
    alpha = truncated(Normal(0.01,0.1),0,Inf)
    kappa = truncated(Normal(0.05,0.1),0,Inf)
    v_plus = truncated(Normal(0.03,0.1),0,Inf)
    v_minus = truncated(Normal(-0.03,0.1),-Inf,0)
    L = truncated(Normal(0.790,0.119),0,Inf)
    p_coh = Beta(45,5)
    p_icoh = Beta(12,3)
    theta_prior = [tau,alpha,kappa,v_plus,v_minus,L,p_coh,p_icoh]
    return sum(logpdf.(theta_prior,th))
end
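#the joint log prior is the sum of the componentwise log densities; for instance, evaluated at
#the simulation values used below (with tau rescaled to 0.5) plus illustrative p_coh, p_icoh:
#logpriorjoint([0.5, 0.02, 0.01, 0.03, -0.05, 0.8, 0.9, 0.8]) #returns a finite log density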
function get_xi_and_f(y,theta::Vector{TT},angle,dt,x0) where {TT}
    tau, alpha, kappa, v_plus, v_minus, L, p_coh, p_icoh = theta
    q_coh = 1-p_coh
    q_icoh = 1-p_icoh
    #transition matrix; typed via TT so ForwardDiff dual numbers propagate
    #see https://stackoverflow.com/questions/68485811/julia-roots-find-zero-with-forwarddiff-dual-type
    P = Matrix{TT}(undef, 4, 4)
    P[:,1] = [p_icoh*p_icoh p_coh*q_coh p_coh*q_coh q_icoh*q_icoh]
    P[:,2] = [p_icoh*q_icoh p_coh*p_coh q_coh*q_coh p_icoh*q_icoh]
    P[:,3] = [p_icoh*q_icoh q_coh*q_coh p_coh*p_coh p_icoh*q_icoh]
    P[:,4] = [q_icoh*q_icoh p_coh*q_coh p_coh*q_coh p_icoh*p_icoh]
    tau *= 1000 #rescale precision
    noise = sqrt(dt/tau)
    get_force(state) = (state>0)*v_plus + (state<0)*v_minus
    T = size(y,1)
    f = Vector{TT}(undef,T) #normalising constants from the forward algorithm
    xi = Matrix{TT}(undef,T,4) #filtered probabilities of the four hidden states
    eta = Matrix{TT}(undef,T,4) #per-frame observation densities under each hidden state
    possible_states = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
    xi0 = 0.25*ones(4) #assume initially all states equally likely
    for j in 1:4
        #first frame: the same density under every state, since x0 is the initial position
        eta[1,j] = pdf(MvNormal(x0,noise),y[1,:])
    end
    for t in 2:T
        for j in 1:4
            state = possible_states[j]
            mu = [y[t-1,1] + dt*(-get_force(state[1]) - kappa*(y[t-1,1]-y[t-1,2]-L*cos(angle[t])) -alpha*y[t-1,1]),
                  y[t-1,2] + dt*(get_force(state[2]) - kappa*(y[t-1,2]-y[t-1,1]+L*cos(angle[t])) -alpha*y[t-1,2]) ]
            eta[t,j] = pdf(MvNormal(mu,noise),y[t,:])
        end
    end
    for t in 1:T #forward recursion: predict with P, update with eta, normalise by f[t]
        if t == 1
            f[t] = dot(xi0'*P,eta[t,:])
            xi[t,:] = ((xi0'*P) .* eta[t,:]') ./ f[t]
        else
            f[t] = dot(xi[t-1,:]'*P,eta[t,:])
            xi[t,:] = ((xi[t-1,:]'*P) .* eta[t,:]') ./ f[t]
        end
    end
    return xi, f
end
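#the final loop above is the forward algorithm: at each frame the filtered state probabilities
#are propagated through P, reweighted by the observation densities eta[t,:], and renormalised,
#so f[t] = dot(xi[t-1,:]'*P, eta[t,:]) is the one-step predictive density and the marginal
#likelihood with the hidden states summed out is prod(f), i.e. sum(log.(f)) on the log scale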
function marginalised_loglikelihood(y,theta::Vector{TT},angle,dt,x0) where {TT}
    #enforce the sign constraints on the parameters; outside the support the likelihood is zero
    if all(sign.(theta) .== [1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0])
        _, f = get_xi_and_f(y,theta,angle,dt,x0)
        if any(f .< 0)
            return -Inf
        else
            return sum(log.(f))
        end
    else
        return -Inf
    end
end
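#because the forward algorithm is written in plain Julia with element type TT, ForwardDiff can
#differentiate straight through it; e.g. (illustrative, with y, angle, dt as defined below):
#ForwardDiff.gradient(th -> marginalised_loglikelihood(y, th, angle, dt, y[1,:]), theta)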
function backward_sample_states(xi::Array{TT},theta::Vector{TT}) where {TT}
    #backward pass of forward-filtering backward-sampling: draw the final state from the
    #filtered distribution, then work backwards conditioning on the state just sampled
    T = size(xi,1)
    sigma_sim = Vector{Int64}(undef,T)
    conditional_state_probs = Vector{TT}(undef,4)
    P_col = Vector{TT}(undef,4)
    state_probs = xi[T,:]
    tau, alpha, kappa, v_plus, v_minus, L, p_coh, p_icoh = theta
    q_coh = 1-p_coh
    q_icoh = 1-p_icoh
    P = Matrix{TT}(undef, 4, 4) #transition matrix, as in get_xi_and_f
    P[:,1] = [p_icoh*p_icoh p_coh*q_coh p_coh*q_coh q_icoh*q_icoh]
    P[:,2] = [p_icoh*q_icoh p_coh*p_coh q_coh*q_coh p_icoh*q_icoh]
    P[:,3] = [p_icoh*q_icoh q_coh*q_coh p_coh*p_coh p_icoh*q_icoh]
    P[:,4] = [q_icoh*q_icoh p_coh*q_coh p_coh*q_coh p_icoh*p_icoh]
    sigma_sim[T] = rand(Categorical(state_probs))
    for t in 1:(T-1)
        frame = T+1-t
        P_col = P[:,sigma_sim[frame]]
        conditional_state_probs = log.(P_col) .+ log.(xi[frame-1,:])
        if (isinf(sum(conditional_state_probs)))
            #some log-probabilities are -Inf: normalise directly on the probability scale
            sigma_sim[frame-1] = rand(Categorical(exp.(conditional_state_probs)./sum(exp.(conditional_state_probs))));
        else
            sigma_sim[frame-1] = rand(Categorical(softmax(conditional_state_probs,1)));
        end
    end
    return sigma_sim
end
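#combined with get_xi_and_f this gives forward-filtering backward-sampling of a hidden state
#trajectory; an illustrative (not executed here) usage sketch with a valid parameter vector theta:
#xi, _ = get_xi_and_f(y, theta, angle, dt, y[1,:])
#sigma_draw = backward_sample_states(xi, theta) #one joint draw of a state (1..4) per frame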
function rbarker(x,grad,sigma,diag_sd)
    # x: current location (vector)
    # grad: target log-posterior gradient (vector)
    # sigma: proposal stepsize (scalar)
    # diag_sd: standard deviation for each component (vector)
    #diag_sd is a standard deviation, so the covariance is its elementwise square
    z = sigma.*rand(MvNormal(zeros(length(grad)),diagm(diag_sd.^2)))
    b = 2 .* (rand(length(grad)) .< 1 ./ (1 .+ exp.(-grad.*z))) .- 1 #componentwise sign flips
    return(x .+ z .* b)
end
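#the Barker proposal perturbs each component by z_i and flips the sign of z_i with probability
#1/(1+exp(-grad_i*z_i)): where the log-posterior gradient is strongly positive the move is almost
#always uphill, while a near-zero gradient recovers a symmetric random walk in that component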
## the log proposal-density ratio entering the acceptance rate is computed as follows
function log_q_ratio_barker(x,y,grad_x,grad_y)
    # x: current location (vector)
    # y: proposed location (vector)
    # grad_x: target log-posterior gradient at x (vector)
    # grad_y: target log-posterior gradient at y (vector)
    beta1 = -grad_y .* (x .- y)
    beta2 = -grad_x .* (y .- x)
    # compute with the log1p-exp trick for numerical stability
    A = sum(-(max.(beta1,0)+log1p.(exp.(-abs.(beta1))))+
            (max.(beta2,0)+log1p.(exp.(-abs.(beta2)))))
    return A
end
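#componentwise this is A = sum(log(1+exp(beta2)) - log(1+exp(beta1))); the identity
#log(1+exp(b)) = max(b,0) + log1p(exp(-abs(b))) avoids overflow when abs(b) is large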
#step size sequence for adaptation
function gamma(t,kappa=0.6)
    return t^(-kappa)
end
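#e.g. gamma(1) = 1.0, gamma(10) ≈ 0.251, gamma(1000) ≈ 0.016: the adaptation steps shrink
#polynomially, giving the diminishing adaptation needed for the chain to remain valid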
###############################################################
###########
#simulate some data
#fix dynamic parameters
#can we reconstruct the skeleton?
# Set a seed for reproducibility.
Random.seed!(14);
dt=2
# T = 50
# skeleton = [[1,15,36,40,T+1],[1,13,33,41,T+1]]
T = 250
skeleton = [[1,15,36,40,58,61,69,84,120,160,224,246,T+1],[1,13,33,41,59,65,76,110,119,165,180,186,220,244,T+1]]
sigma0 = [1.0,-1.0]
state_sequence = compute_sequence_given_skeleton(skeleton,sigma0,T)
tau = 500
alpha = 0.02
kappa = 0.01
v_plus = 0.03
v_minus = -0.05
L = 0.8
theta_d = [tau/1000, alpha, kappa, v_plus, v_minus, L] #NB rescaling tau
get_force(state) = (state>0)*v_plus + (state<0)*v_minus
#generate data
u0 = [1.1 0.1]
angle = repeat([0.0],outer=T)
y = repeat([0.0 0.0],outer=T)
y[1,:] = u0
for t in 2:T
    y[t,1] = y[t-1,1] + dt*(-get_force(state_sequence[t,1]) - kappa*(y[t-1,1]-y[t-1,2]-L*cos(angle[t])) -alpha*y[t-1,1]) + sqrt(dt/tau)*randn()
    y[t,2] = y[t-1,2] + dt*(get_force(state_sequence[t,2]) - kappa*(y[t-1,2]-y[t-1,1]+L*cos(angle[t])) -alpha*y[t-1,2]) + sqrt(dt/tau)*randn()
    # y[t,:] .+= sqrt(dt/tau)*randn(2)
end
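#each step above is an Euler-Maruyama update: the drift combines the spring coupling (kappa) and
#centring (alpha) terms with the state-dependent force get_force, and the noise has standard
#deviation sqrt(dt/tau), matching the means and noise scale assumed in get_xi_and_f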
p1 = plot(state_sequence[:,1]) #plot the simulated hidden states for both sisters
plot!(p1,state_sequence[:,2])
p2 = plot(y) #and the simulated sister trajectories
plot(p1,p2)
#############
grad_ll(th) = ForwardDiff.gradient(x -> marginalised_loglikelihood(y, x, angle, dt, y[1,:]), th)
#rebind the parameter names (used above as scalars to simulate data) to their prior distributions
tau = Gamma(0.5,1)
alpha = truncated(Normal(0.01,0.1),0,Inf)
kappa = truncated(Normal(0.05,0.1),0,Inf)
v_plus = truncated(Normal(0.03,0.1),0,Inf)
v_minus = truncated(Normal(-0.03,0.1),-Inf,0)
L = truncated(Normal(0.790,0.119),0,Inf)
p_coh = Beta(45,5)
p_icoh = Beta(12,3)
theta_prior = [tau,alpha,kappa,v_plus,v_minus,L,p_coh,p_icoh] #same priors as in logpriorjoint
@time begin
    Random.seed!(123);
    x0 = y[1,:]
    # theta_h = [0.95,0.7]
    # theta_d = [0.5,0.02,0.01,0.03,-0.05,0.8]
    # theta = vcat(theta_d, theta_h)
    theta = rand.(theta_prior) #initialise from the prior
    # theta_d = theta[1:6]
    # theta_h = theta[7:8]
    theta_store = []
    num_iter = 5000
    target_ap = 0.4 #target acceptance probability
    print_frequency = num_iter ÷ 20
    num_accepted = 0
    iter = 0
    theta_curr = deepcopy(theta)
    #initialise for adaptation
    sigma_vec = Vector{Float64}(undef,num_iter)
    sigma_t = Matrix{Float64}(undef,num_iter,length(theta_curr))
    ## params for adaptation
    sigma = 0.001*2.4/sqrt(length(theta_curr)^(1/3)) #initial stepsize, scaling with dimension as d^(-1/6)
    diag_var = ones(length(theta_curr))
    x_means = deepcopy(theta_curr)
    sigma_vec[1] = sigma
    sigma_t[1,:] = diag_var
    while iter < num_iter
        grad_curr = grad_ll(theta_curr)
        theta_star = rbarker(theta_curr,grad_curr,sigma,sqrt.(diag_var)) #Barker proposal, preconditioned by diag_var
        grad_star = grad_ll(theta_star)
        #now evaluate acceptance ratio
        if mod(iter,print_frequency) == 0
            println("Accepted: ", num_accepted, " Proposed: ", iter)
            println("proposed: ",marginalised_loglikelihood(y,theta_star,angle,dt,x0), " current: ",marginalised_loglikelihood(y,theta_curr,angle,dt,x0),"\n")
        end
        log_ratio = marginalised_loglikelihood(y,theta_star,angle,dt,x0) -
            marginalised_loglikelihood(y,theta_curr,angle,dt,x0) +
            logpriorjoint(theta_star) - logpriorjoint(theta_curr) +
            log_q_ratio_barker(theta_curr,theta_star,grad_curr,grad_star)
        u = rand()
        if log_ratio > log(u)
            #accept
            global num_accepted += 1
            global theta_curr = copy(theta_star)
            append!(theta_store,[theta_star])
        else
            #otherwise reject: store the current value again
            append!(theta_store,[theta_curr])
        end
        global iter += 1
        #adaptation
        # 1 - adapt global scale
        ap = min(1,exp(log_ratio)) #acceptance probability
        log_sigma_2 = log(sigma^2)+gamma(1+iter)*(ap-target_ap)
        global sigma = sqrt(exp(log_sigma_2))
        # 2 - adapt means
        x_means .= x_means .+ gamma(1+iter) .* (theta_curr .- x_means)
        # 3 - adapt diagonal covariance
        diag_var .= diag_var .+ gamma(1+iter) .* ((theta_curr .- x_means).^2 .- diag_var)
        # store adaptation parameters
        sigma_vec[iter] = sigma
        sigma_t[iter,:] = diag_var
        if iter > 10^8 #safety cap
            break
        end
    end
end
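#the loop interleaves a Barker proposal step with diminishing adaptation: the global scale sigma
#is tuned towards the target acceptance rate target_ap, while x_means and diag_var track a running
#mean and diagonal variance of the chain to precondition the proposal; since gamma(1+iter) decays,
#the adaptation fades out over the run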
plot(log.(sigma_t),layout=8) #traces of the adapted per-component variances
thinning = 10 #thin the stored samples for plotting
p_arr = [plot(1:thinning:num_iter,[theta_store[jj][ii] for jj in 1:thinning:length(theta_store)],legend=false) for ii in 1:8]
plot(p_arr...)
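#the MCMCChains package loaded above could be used for posterior summaries; a rough sketch
#(variable names as in this script, burn-in choice illustrative):
#chn = Chains(permutedims(reduce(hcat, theta_store)),
#             ["tau","alpha","kappa","v_plus","v_minus","L","p_coh","p_icoh"])
#summarystats(chn[(num_iter÷2):end, :, :])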