Skip to content

Instantly share code, notes, and snippets.

@willtownes
Created February 16, 2017 04:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save willtownes/fec2de381fbd6fc357a0d99b6e234dfd to your computer and use it in GitHub Desktop.
Save willtownes/fec2de381fbd6fc357a0d99b6e234dfd to your computer and use it in GitHub Desktop.
Adaptive Rejection Sampling without Derivatives
### Adaptive Rejection Sampling
# by Will Townes
rexp_trunc <- function(n, slope = -1, lo = 0, hi = Inf) {
  # Draw n samples from the truncated exponential distribution on [lo, hi]
  # whose density is proportional to exp(slope * x), by inverting the CDF
  # of a single runif(n) draw.
  # The default (slope = -1, lo = 0, hi = Inf) is the standard exponential.
  # lo may be -Inf: then slope must be > 0 and hi finite.
  # hi may be +Inf: then slope must be < 0 and lo finite.
  # With both bounds finite any slope is allowed; slope == 0 gives a uniform
  # draw on [lo, hi].
  u <- runif(n)
  if (lo == -Inf) {
    stopifnot(slope > 0 && hi < Inf)
    return(hi + log(u) / slope)
  }
  if (hi == Inf) {
    stopifnot(slope < 0 && lo > -Inf)
    return(lo + log(u) / slope)
  }
  width <- hi - lo
  if (slope == 0) {
    # flat density; the exponential inversion below would be 0/0
    return(lo + u * width)
  }
  if (slope > 0) {
    # Anchor at hi so expm1() only sees a non-positive argument; anchoring
    # at lo overflows to Inf once slope * width exceeds ~709.
    hi + log1p(u * expm1(-slope * width)) / slope
  } else {
    lo + log1p(u * expm1(slope * width)) / slope
  }
}
ars_wt_formula <- function(a, b, c, d, numeric_zero = 1e-8) {
  # Integral of exp(a * x + b) over [c, d], i.e.
  # (1/a) * exp(b) * (exp(a * d) - exp(a * c)):
  # the un-normalized mass of one piece of the envelope.
  # a is the slope and b the intercept of the line; c <= d are the endpoints
  # (c may be -Inf, d may be +Inf).
  # For |a| < numeric_zero the line is treated as flat: exp(b) * (d - c).
  dmc <- d - c
  stopifnot(dmc >= 0)
  if (abs(a) < numeric_zero) {
    return(exp(b) * dmc)
  }
  # Factor exp() out at the endpoint where the line is highest, so the
  # remaining factor lies in [0, 1]. The naive form gives Inf * 0 or
  # 0 * Inf (NaN) when b and a * c / a * d are large with opposite signs.
  if (a > 0) {
    exp(b + a * d) * (-expm1(-a * dmc)) / a
  } else {
    exp(b + a * c) * expm1(a * dmc) / a
  }
}
calc_wts_inner <- function(j, xpts, xstar, a, b) {
  # Envelope mass of the two halves of interval j = [xpts[j], xpts[j + 1]],
  # split at the breakpoint xstar[j]; returns c(left, right).
  #   left half  [xpts[j], xstar[j]]     lies under line j - 1
  #   right half [xstar[j], xpts[j + 1]] lies under line j + 1
  # If xstar[j] equals xpts[j], the left half is skipped (mass 0, its line is
  # never looked up). Otherwise, if it equals xpts[j + 1], the right half is
  # skipped. This is what lets the first and last intervals avoid the
  # nonexistent lines 0 and length(a) + 1.
  # xstar, a (slopes) and b (intercepts) share one index; xpts holds the
  # interval boundaries and is one element longer.
  on_left_end <- xstar[j] == xpts[j]
  on_right_end <- !on_left_end && xstar[j] == xpts[j + 1]
  masses <- c(0, 0)
  if (!on_left_end) {
    masses[1] <- ars_wt_formula(a[j - 1], b[j - 1], xpts[j], xstar[j])
  }
  if (!on_right_end) {
    masses[2] <- ars_wt_formula(a[j + 1], b[j + 1], xstar[j], xpts[j + 1])
  }
  masses
}
calc_wts <- function(xpts, xstar, a, b, lo, hi, idx = seq_along(a)) {
  # Un-normalized envelope mass of every sub-interval, returned as a numeric
  # vector of length 2 * length(idx) + 2 laid out as:
  #   left tail [lo, xpts[1]],
  #   interval 1 (left half, right half), ..., interval n (left, right),
  #   right tail [xpts[n + 1], hi].
  # The tails extend the first and last lines; lo may be -Inf, hi may be Inf.
  # The tails read a[1] and a[length(idx)], so idx is expected to be
  # seq_along(a).
  n_ivl <- length(idx)
  left_tail <- ars_wt_formula(a[1], b[1], lo, xpts[1])
  right_tail <- ars_wt_formula(a[n_ivl], b[n_ivl], xpts[n_ivl + 1], hi)
  # one column per interval: row 1 = left half, row 2 = right half
  halves <- vapply(idx, calc_wts_inner, numeric(2), xpts, xstar, a, b)
  # column-major flattening keeps the (left, right) pairs adjacent;
  # as.vector drops any names picked up from the inputs
  as.vector(c(left_tail, halves, right_tail))
}
subinterval_to_interval <- function(j) {
  # Map a sub-interval index (as laid out by calc_wts) to its interval index:
  # 1 -> 0 (left tail), 2k and 2k + 1 -> k, and the last index
  # 2 * nIvl + 2 -> nIvl + 1 (right tail). Floor division by 2.
  j %/% 2
}
get_xstar <- function(xpts, a, b, idx = seq_along(a)) {
  # Breakpoints of the derivative-free envelope, one per interval.
  # xpts: sorted grid with length(idx) + 1 points; a, b: slope and intercept
  # of the chord over each interval [xpts[j], xpts[j + 1]].
  # Inside interval j the envelope switches from line j - 1 to line j + 1
  # where those two lines cross; xstar[j] is that crossing, clamped into the
  # interval. The first and last intervals have only one neighbouring line,
  # so their breakpoints are pinned to the outer grid points.
  nIvl <- length(idx)
  lefts <- xpts[idx]
  rights <- xpts[idx + 1]
  xstar <- lefts
  xstar[nIvl] <- xpts[nIvl + 1]
  if (nIvl > 2) {
    # crossing of lines j - 1 and j + 1, for j = 2, ..., nIvl - 1
    xstar[2:(nIvl - 1)] <- -diff(b, lag = 2) / diff(a, lag = 2)
  }
  # Lines j - 1 and j + 1 coincide (locally linear log-density) -> 0/0.
  # Any breakpoint is valid then; use the left endpoint. Without this the
  # NaN survives the clamps below (which() drops NA comparisons).
  undefined <- which(is.nan(xstar))
  xstar[undefined] <- lefts[undefined]
  # Rounding in almost-linear regions can push the crossing outside its
  # interval, which would otherwise lead to negative weights.
  too_low <- which(xstar < lefts)
  xstar[too_low] <- lefts[too_low]
  too_hi <- which(xstar > rights)
  xstar[too_hi] <- rights[too_hi]
  xstar
}
ars <- function(func, nSample = 1, xpts, lo = 0, hi = 1, logscale = TRUE,
                verbose = TRUE) {
  # Draw nSample values from a univariate log-concave density f using
  # derivative-free adaptive rejection sampling. The envelope is built from
  # extended chords of log(f) between grid points, and each rejected
  # proposal is added to the grid to tighten it.
  #
  # func:     log(f) if logscale is TRUE, otherwise f itself.
  # nSample:  number of draws to return.
  # xpts:     initial grid, at least 3 points, all where f(x) > 0; it is
  #           sorted internally.
  # lo, hi:   bounds of the domain of f; either may be infinite. With
  #           lo = -Inf the log-density must increase between the first two
  #           grid points; with hi = Inf it must decrease between the last two.
  # verbose:  currently unused.
  # Returns a numeric vector of length nSample.
  h <- if (logscale) func else function(x) log(func(x))
  xpts <- sort(xpts)
  stopifnot(length(xpts) >= 3)
  nIvl <- length(xpts) - 1
  ypts <- h(xpts)
  stopifnot(all(is.finite(ypts)))
  idx <- seq_len(nIvl) # intervals indexed by left endpoint
  # slope and intercept of the chord over each interval
  a <- diff(ypts) / diff(xpts)
  b <- ypts[idx] - a * xpts[idx]
  # fail early with a clearer message than sample.int() would give
  check_wts <- function(w) {
    if (!all(is.finite(w))) {
      stop("ars: envelope weights are not finite. If lo = -Inf (hi = Inf), ",
           "the log-density must increase (decrease) between the first ",
           "(last) two grid points.", call. = FALSE)
    }
    w
  }
  xstar <- get_xstar(xpts, a, b, idx)
  # w has 2 * nIvl + 2 un-normalized entries: w[1] is the left tail,
  # w[2k] and w[2k + 1] the left/right halves of interval k, and the last
  # entry the right tail. sample.int() normalizes prob itself.
  w <- check_wts(calc_wts(xpts, xstar, a, b, lo, hi, idx))
  res <- rep(NA_real_, nSample)
  nRes <- 0
  while (nRes < nSample) {
    j <- sample.int(2 * nIvl + 2, 1, prob = w)
    i <- subinterval_to_interval(j) # 0 (left tail) .. nIvl + 1 (right tail)
    if (j %% 2 == 1) {
      # right half of interval i, under line i + 1 ("c" for current)
      c_slope <- a[i + 1]
      c_icpt <- b[i + 1]
      c_lo <- if (i > 0) xstar[i] else lo
      c_hi <- xpts[i + 1]
    } else {
      # left half of interval i, under line i - 1
      c_slope <- a[i - 1]
      c_icpt <- b[i - 1]
      c_lo <- xpts[i]
      c_hi <- if (i <= nIvl) xstar[i] else hi
    }
    # x is a draw from the envelope restricted to the chosen sub-interval
    x <- rexp_trunc(1, slope = c_slope, lo = c_lo, hi = c_hi)
    hx <- h(x)
    # accept with probability exp(hx - envelope(x)); adding an Exp(1) draw
    # to hx is equivalent to subtracting log(runif(1))
    if (c_slope * x + c_icpt <= hx + rexp(1)) {
      nRes <- nRes + 1
      res[nRes] <- x
    } else if (is.finite(hx) && !(x %in% xpts)) {
      # Rejected: refine the envelope by inserting x into the grid.
      # A duplicate grid point or an infinite log-density would give 0/0 or
      # infinite chord slopes, so such proposals are simply discarded.
      nIvl <- nIvl + 1
      idx <- append(idx, nIvl)
      xpts <- append(xpts, x, i) # x lies in interval i, so order is kept
      ypts <- append(ypts, hx, i)
      a <- append(a, NA, i)
      b <- append(b, NA, i)
      if (i < nIvl) {
        # chord from x to its right neighbour (covers the left tail, i == 0)
        a[i + 1] <- (ypts[i + 2] - ypts[i + 1]) / (xpts[i + 2] - xpts[i + 1])
        b[i + 1] <- ypts[i + 1] - a[i + 1] * xpts[i + 1]
      }
      if (i > 0) {
        # chord from the left neighbour to x (covers the right tail)
        a[i] <- (ypts[i + 1] - ypts[i]) / (xpts[i + 1] - xpts[i])
        b[i] <- ypts[i] - a[i] * xpts[i]
      }
      # to do: only update the parts of xstar and w that change
      xstar <- get_xstar(xpts, a, b, idx)
      w <- check_wts(calc_wts(xpts, xstar, a, b, lo, hi, idx))
    }
  }
  res
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment