Skip to content

Instantly share code, notes, and snippets.

@reedacartwright
Last active December 13, 2017 22:15
Show Gist options
Save reedacartwright/7b5339862c9abd388e10b101bfed6f07 to your computer and use it in GitHub Desktop.
HMM to estimate coverage
#!/usr/bin/Rscript
# Authors: David Winter
# Joanna Malukiewicz
# Reed A. Cartwright
library(HiddenMarkov)
library(data.table) #makes reading a data table faster than using R data frame functions
library(jsonlite)
## Report the start time, then collect the trailing command-line arguments
## (everything after the script name).
start_time = as.character(Sys.time())
cat("Started at ", start_time, "\n")
args = commandArgs(trailingOnly = TRUE)
# Replacement for HiddenMarkov's forwardback(), installed into the package
# namespace with assignInNamespace() below, presumably so that the package's
# own BaumWelch()/Viterbi() code picks it up (the original comment describes
# this as a bug fix for the HiddenMarkov package).
#
# What this version does: the emission densities are computed on the log
# scale, and each row is shifted by its maximum before exponentiating. Every
# row of `prob` therefore contains a 1, and an observation that is very
# unlikely under every state cannot underflow to an all-zero row. The removed
# row maxima are added back to logalpha, logbeta and LL after the recursions.
# NOTE(review): presumably the upstream version exponentiates the raw
# densities; confirm against the installed HiddenMarkov source.
#
# Arguments (same signature as the package function):
#   x        observed sequence (length n)
#   Pi       m x m transition matrix
#   delta    initial state distribution (length m)
#   distn    distribution name passed to HiddenMarkov:::makedensity
#   pm, pn   parameter lists forwarded to the density function
#   fortran  if TRUE, run the recursions with the package's compiled
#            loop1/loop2 routines; otherwise use the R loops below
# Returns list(logalpha = n x m matrix, logbeta = n x m matrix, LL = scalar).
forwardback = function (x, Pi, delta, distn, pm, pn = NULL, fortran=TRUE) {
  m <- nrow(Pi)
  n <- length(x)
  dfunc <- HiddenMarkov:::makedensity(distn)
  # log emission densities, one column per state
  prob <- matrix(as.double(0), nrow=n, ncol=m)
  for (k in seq_len(m))
    prob[,k] <- dfunc(x=x, HiddenMarkov:::getj(pm, k), pn, log=TRUE)
  # shift each row so its maximum is 0 before leaving the log scale;
  # probmax recycles down the columns, i.e. row i is shifted by probmax[i]
  probmax = apply(prob,1,max)
  prob = exp(prob-probmax)
  # forward probabilities alpha_ij
  phi <- as.double(delta)
  logalpha <- matrix(as.double(rep(0, m*n)), nrow=n)
  lscale <- as.double(0)
  if (fortran!=TRUE){
    # loop1 using R code: normalised forward recursion, accumulating the
    # log of the normalising constants in lscale
    for (i in seq_len(n)){
      if (i > 1) phi <- phi %*% Pi
      phi <- phi*prob[i,]
      sumphi <- sum(phi)
      phi <- phi/sumphi
      lscale <- lscale + log(sumphi)
      logalpha[i,] <- log(phi) + lscale
    }
    LL <- lscale
  } else{
    if (!is.double(Pi)) stop("Pi is not double precision")
    if (!is.double(prob)) stop("prob is not double precision")
    memory0 <- rep(as.double(0), m)
    loop1 <- .Fortran("loop1", m, n, phi, prob, Pi, logalpha,
                      lscale, memory0, PACKAGE="HiddenMarkov")
    logalpha <- loop1[[6]]
    LL <- loop1[[7]]
  }
  # backward probabilities beta_ij
  logbeta <- matrix(as.double(rep(0, m*n)), nrow=n)
  phi <- as.double(rep(1/m, m))
  lscale <- as.double(log(m))
  if (fortran!=TRUE){
    # loop2 using R code. Row n of logbeta stays 0 (beta_n = 1). Iterate
    # i = n-1, ..., 1; rev(seq_len(n - 1)) is empty when n == 1, whereas
    # seq(n - 1, 1, -1) would stop with "wrong sign in 'by' argument".
    for (i in rev(seq_len(n - 1))){
      phi <- Pi %*% (prob[i+1,]*phi)
      logbeta[i,] <- log(phi) + lscale
      sumphi <- sum(phi)
      phi <- phi/sumphi
      lscale <- lscale + log(sumphi)
    }
  } else{
    memory0 <- rep(as.double(0), m)
    loop2 <- .Fortran("loop2", m, n, phi, prob, Pi, logbeta,
                      lscale, memory0, PACKAGE="HiddenMarkov")
    logbeta <- loop2[[6]]
  }
  # undo the row shifts: alpha_i involves rows 1..i of prob, beta_i involves
  # rows i+1..n, and the likelihood involves all rows
  logalpha = logalpha + cumsum(probmax)
  logbeta = logbeta + rev(cumsum(rev(c(probmax[-1],0))))
  LL = LL + sum(probmax)
  return(list(logalpha=logalpha, logbeta=logbeta, LL=LL))
}
assignInNamespace("forwardback", forwardback, "HiddenMarkov")
# M-step for negative binomial emissions (mu / size parameterisation).
# NOTE(review): presumably found by HiddenMarkov's BaumWelch() through its
# "Mstep.<distn>" naming convention, since distn is "nbinom" below.
#
# For each hidden state i:
#   mu[i]   = posterior-weighted mean of x
#   size[i] = root of the weighted score equation for `size` with the mean
#             held at mu[i]; the search starts around the current pm$size[i]
#
# Arguments:
#   x     observed counts
#   cond  E-step output; cond$u holds one column of state weights per state
#   pm    current parameters; only pm$size is read (to seed the root search)
#   pn    unused, kept to match the Mstep calling convention
# Returns list(mu = , size = ), one entry per state. Also prints the new
# estimates as a progress report.
Mstep.nbinom = function(x, cond, pm, pn) {
  w <- cond$u
  n_states <- ncol(w)
  mu <- numeric(n_states)
  size <- numeric(n_states)
  for (i in seq_len(n_states)) {
    ww <- w[, i]
    A <- weighted.mean(x, ww)
    # Derivative of the weighted NB log-likelihood with respect to size r at
    # mean A. The usual extra term sum(ww * (A - x)) / (A + r) is zero
    # because A is the weighted mean of x, so it is omitted.
    score <- function(r) {
      if (r <= 0) {
        return(NA)
      }
      u <- digamma(x + r) - digamma(r) - log1p(A / r)
      as.vector(ww %*% u)
    }
    r0 <- pm$size[i]
    # bracket the previous estimate; extendInt lets uniroot widen the
    # interval when the root lies outside it
    o <- uniroot(score, c(r0 * 0.1, r0 * 1.9), extendInt = "yes")
    size[i] <- o$root
    mu[i] <- A
  }
  # Construct return value
  el <- list(mu = mu, size = size)
  cat("params =\n")
  print(do.call(cbind, el))
  el
}
# Validate the command-line arguments and read the data column.
# Arguments, as used by this script:
#   args[1]  1-based column number to analyse
#   args[2]  input file; the column is extracted with `cut -f`, so the file
#            is assumed to be tab-delimited
#   args[3]  path of the JSON file written at the end of the script
if (length(args) < 3) {
  stop("usage: <column id> <input file> <output file>")
}
num = as.numeric(args[1])
if (!is.finite(num) || num < 1 || num != round(num)) {
  stop("invalid column id")
}
# integer form so paste0() never renders it in scientific notation
num = as.integer(num)
if (!file.exists(args[2])) {
  stop("input file does not exist")
}
# shQuote() passes paths with spaces or shell metacharacters through intact
# and prevents them from being interpreted by the shell
cmd = paste0("cut -f", num, " ", shQuote(args[2]))
data = fread(cmd = cmd)
# select by position: if fread detects a header line, the column is named
# after it rather than "V1"
x = as.integer(data[[1]])
if (length(x) == 0) {
  stop("no data read from the input column")
}
# # simulate an HMM (uncomment to fit simulated data instead of the input)
#Pi = matrix(c(0.95, 0.05, 0.05, 0.95),2)
#delta = c(0.26, 0.74)
#pm = list(mu=c(3,10), size=c(1.5,2))
# mod = dthmm(NULL, Pi, delta, "nbinom", pm=pm, discrete=TRUE)
# sim_hmm = simulate(mod,1e6)
# x = sim_hmm$x
#
# Alternative two-state starting values. These used to be assigned here and
# were then immediately overwritten by the three-state values below:
# Pi = matrix(c(0.99, 0.01, 0.01, 0.99),2)
# delta = c(0.5, 0.5)
# pm = list(mu=c(2,10), size=c(100,100))
#
# Starting values for Baum-Welch: three hidden states.
# Sticky transitions: each row sums to 1, with 0.98 on the diagonal.
Pi = matrix(c(0.98, 0.01, 0.01,
              0.01, 0.98, 0.01,
              0.01, 0.01, 0.98), nrow = 3, byrow = TRUE)
# uniform initial state distribution
delta = c(1/3, 1/3, 1/3)
# Well-separated starting means. NB variance is mu + mu^2/size, so size = 100
# starts each state close to Poisson.
pm = list(mu=c(0.1,2,20), size=c(100,100,100))
# Define HMM Model: discrete-time HMM with negative binomial emissions,
# initialised from Pi, delta and pm above
mod = dthmm(x, Pi, delta, "nbinom", pm=pm, discrete=TRUE)
# Estimate Parameters with Baum-Welch
# NOTE(review): posdiff = FALSE presumably stops the fit from terminating
# when the log-likelihood decreases between iterations; confirm in
# ?bwcontrol
controls = bwcontrol(posdiff = FALSE)
res = BaumWelch(mod, controls)
# Estimate site Components: most likely hidden state for every site
site_components = Viterbi(res)
# Construct output object
output = list(
  LL = res$LL,
  iter = res$iter,
  Pi = res$Pi,
  pm = res$pm,
  delta = res$delta,
  viterbi = site_components
)
# digits = NA writes numbers at full precision. jsonlite's default
# (digits = 4) would round the log-likelihood, transition probabilities and
# emission parameters in the saved results.
output = toJSON(output, pretty=TRUE, digits=NA)
cat(output, file=args[3])
cat("Ended at ", as.character(Sys.time()), "\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment