Simple Gibbs sampler in Julia

The Gibbs sampler discussed on Darren Wilkinson's blog and also on Dirk Eddelbuettel's blog has been implemented in several languages, the first of which was R.

This gist was written for demonstration purposes only and would not exist without the help of @Ionizing.

The task is to create a Gibbs sampler for the unscaled density

    f(x, y) = k x^2 \exp(-x y^2 - y^2 + 2y - 4x)

using the conditional distributions

    x | y \sim Gamma(3, y^2 + 4)
    y | x \sim Normal(\frac{1}{1+x}, \frac{1}{2(1+x)})
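
Both conditionals can be read off directly from the joint density; here is a quick check (the second Gamma argument is a rate and the second Normal argument is a variance, matching the 1/sqrt(2(1+x)) standard deviation used in the code below):

```latex
\begin{aligned}
f(x, y) &\propto x^{2}\, e^{-(y^{2} + 4)\,x}
  &&\Longrightarrow\quad x \mid y \sim \mathrm{Gamma}\bigl(3,\ \text{rate} = y^{2} + 4\bigr),\\
f(x, y) &\propto e^{-(1 + x)\,y^{2} + 2y}
  \propto \exp\!\Bigl(-(1 + x)\bigl(y - \tfrac{1}{1 + x}\bigr)^{2}\Bigr)
  &&\Longrightarrow\quad y \mid x \sim \mathrm{Normal}\Bigl(\tfrac{1}{1 + x},\ \tfrac{1}{2(1 + x)}\Bigr).
\end{aligned}
```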

Dirk's version of Darren's original R function is

```r
Rgibbs <- function(N, thin) {
    mat <- matrix(0, ncol = 2, nrow = N)
    x <- 0
    y <- 0
    for (i in 1:N) {
        for (j in 1:thin) {
            # name the second Gamma parameter explicitly (rate, not scale) to avoid ambiguity
            x <- rgamma(1, shape = 3, rate = y * y + 4)
            y <- rnorm(1, 1/(x + 1), 1/sqrt(2 * (x + 1)))
        }
        mat[i, ] <- c(x, y)
    }
    mat
}
```

Dirk also shows the use of the R byte compiler on this function

```r
library(compiler)
RCgibbs <- cmpfun(Rgibbs)
```

In the examples directory of the Rcpp package, Dirk provides an R script using the inline, Rcpp and RcppGSL packages to implement this sampler in C++ code callable from R and time the results. On my desktop computer, timing 10 replications of Rgibbs(20000, 200) and the other versions produces

```
               test replications elapsed  relative user.self sys.self
4  GSLGibbs(N, thn)           10   8.338  1.000000     8.336    0.000
3 RcppGibbs(N, thn)           10  13.285  1.593308    13.285    0.000
2   RCgibbs(N, thn)           10 369.843 44.356320   369.327    0.032
1    Rgibbs(N, thn)           10 473.511 56.789518   472.754    0.044
```

A naive translation of Rgibbs to Julia can use the same samplers for the gamma and normal distributions as R does. The C code for R's d-p-q-r functions for probability densities, cumulative distributions, quantiles and random sampling can be compiled into a separate Rmath library. These sources are included with the Julia sources, and Julia functions with similar calling sequences are available in "extras/Rmath.jl" (in current Julia, via the Rmath package loaded below). A note on reproducibility: to reproduce results from the samplers that use Julia's own random number generator, set the seed with Random.seed!(seed) from the Random standard library and report the seed you used, so that others can reproduce your results exactly.

```julia
using Rmath

function JGibbs1(N::Int, thin::Int)
    mat = zeros(Float64, N, 2)
    x = 0.0
    y = 0.0
    for i = 1:N
        for j = 1:thin
            # Rmath's rgamma takes a scale parameter, hence the reciprocal of
            # the rate used in the R version
            x = rgamma(1, 3, 1/(y*y + 4))[1]
            y = rnorm(1, 1/(x + 1), 1/sqrt(2(x + 1)))[1]
        end
        mat[i, :] = [x, y]
    end
    mat
end
```

You can see that JGibbs1 is essentially the same code as Rgibbs with minor adjustments for syntax. A similar timing on the same computer gives

```julia
julia> sum([@elapsed JGibbs1(20000, 200) for i=1:10])
27.748079776763916
julia> sum([@elapsed JGibbs1(20000, 200) for i=1:10])
27.782002687454224
```

which is about 17 times faster than Rgibbs and about 13 times faster than RCgibbs, and within roughly a factor of 2 of the compiled code in RcppGibbs.

One of the big differences between this function and the compiled C++ function, RcppGibbs, is that the compiled function calls the underlying C code for the samplers directly, avoiding the overhead of creating a vector of length 1 and indexing to get the first element. As these operations are done in the inner loop of JGibbs1 their overhead mounts up.

Fortunately, Julia can call a C function directly with ccall: you supply the symbol and library, the return and argument types, and the arguments themselves. It looks like

```julia
using Rmath
import Rmath: libRmath           # handle of the underlying Rmath shared library

function JGibbs2(N::Int, thin::Int)
    mat = zeros(Float64, N, 2)
    x = 0.0
    y = 0.0
    for i = 1:N
        for j = 1:thin
            # call the C samplers directly, avoiding the length-1 vectors
            x = ccall((:rgamma, libRmath), Float64, (Float64, Float64), 3.0, 1/(y*y + 4))
            y = ccall((:rnorm, libRmath), Float64, (Float64, Float64), 1/(x + 1), 1/sqrt(2*(x + 1)))
        end
        mat[i, :] = [x, y]
    end
    mat
end
```

The timings are considerably faster, essentially the same as those of RcppGibbs:

```julia
julia> sum([@elapsed JGibbs2(20000, 200) for i=1:10])
13.596416234970093
julia> sum([@elapsed JGibbs2(20000, 200) for i=1:10])
13.584651470184326
```

If we switch to the native Julia random samplers for the gamma and normal distributions, the function becomes

```julia
using Distributions

function JGibbs3(N::Int, thin::Int)
    mat = zeros(Float64, N, 2)
    x = 0.0
    y = 0.0
    for i = 1:N
        for j = 1:thin
            # Gamma(shape, scale); note that the old randg() used in earlier
            # versions of this example is deprecated as of Julia 1.6
            x = rand(Gamma(3, 1/(y*y + 4)), 1)[1]
            y = rand(Normal(1/(x + 1), 1/sqrt(2*(x + 1))), 1)[1]
        end
        mat[i, :] = [x, y]
    end
    mat
end
```

and the timings are

```julia
julia> sum([@elapsed JGibbs3(20000, 200) for i=1:10])
6.603794574737549
julia> sum([@elapsed JGibbs3(20000, 200) for i=1:10])
6.58268928527832
```

So now we are beating the compiled code from both RcppGibbs (which uses slower samplers) and GSLGibbs (whose samplers are faster, though not as fast as Julia's) while writing code that looks very much like the original R function.
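
As an aside, JGibbs3 still pays the length-1-vector cost noted earlier, since rand(d, 1)[1] allocates a one-element vector on every draw. Here is a minimal sketch of a variant that draws scalars directly (the name JGibbs4 is only for illustration; the threaded sampler further below uses the same style):

```julia
using Distributions

# Hypothetical variant of JGibbs3 that draws scalars directly instead of
# allocating a length-1 vector on every iteration.
function JGibbs4(N::Int, thin::Int)
    mat = zeros(Float64, N, 2)
    x = 0.0
    y = 0.0
    for i in 1:N
        for j in 1:thin
            x = rand(Gamma(3, 1/(y^2 + 4)))                  # scale = 1/(y^2 + 4)
            y = rand(Normal(1/(x + 1), 1/sqrt(2*(x + 1))))   # second argument is the sd
        end
        mat[i, 1] = x
        mat[i, 2] = y
    end
    mat
end
```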

But wait, there's more!

This computer has a 4-core processor (AMD Athlon(tm) II X4 635 Processor @ 2.9 GHz) and Julia can take advantage of that. The original post used Julia's distributed arrays: Julia was started with worker processes,

```sh
julia -p 4
```

and a distributed array was declared like an ordinary array, with an element type and dimensions, plus two additional arguments: the dimension on which to distribute the array across the processes and a function, usually anonymous, stating how each local chunk is constructed. Two distributed samplers were defined: dJGibbs3a, which leaves the result as a distributed array, and dJGibbs3b, which converts the result to a single array in the parent process. This fork replaces them with a multithreaded sampler, JGibbs5, shown after the sketch below; to run it, start Julia with `julia -t 4` (or set the JULIA_NUM_THREADS environment variable) so that four threads are available.
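
For reference, the original distributed samplers used the pre-1.0 darray constructor, which no longer exists; the following is only a rough sketch of what they could look like with today's DistributedArrays.jl package, not code from the original post:

```julia
using Distributed
addprocs(4)                                    # or start Julia with `julia -p 4`
@everywhere using Distributions, DistributedArrays
# JGibbs3 (defined above) must also be defined on the workers,
# e.g. by wrapping its definition in @everywhere.

# Leave the result as a distributed array; each worker fills its own block of rows.
dJGibbs3a(N::Int, thin::Int) =
    DArray(I -> JGibbs3(length(I[1]), thin), (N, 2), workers(), [nworkers(), 1])

# Collect the distributed result into an ordinary Array in the parent process.
dJGibbs3b(N::Int, thin::Int) = convert(Array{Float64,2}, dJGibbs3a(N, thin))
```

The multithreaded version used in this fork is JGibbs5: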

Pkg.add("BenchmarkTools")
using BenchmarkTools, Random, Distributions
function JGibbs5(N::Int, thin::Int)
  nt = Threads.nthreads()

  # mat = zeros(Float64, N, 2)
  mat = zeros(Float64, N, 2)

  # partition
  parts = Iterators.partition(1:N, N ÷ Threads.nthreads() + 1) |> collect
  
  Threads.@threads for p in parts
    Random.seed!(114514);
    x   = 0.
    y   = 0.
    for i = p
      for j = 1:thin
        x = rand(Gamma(3, 1/(y^2 + 4)))
        y = rand(Normal(1/(x + 1), 1/sqrt(2*(x + 1))))
      end
      mat[i,1] = x
      mat[i,2] = y
    end
  end
  mat
end

In JGibbs5, the rows of the result matrix are split into contiguous chunks with Iterators.partition and Threads.@threads hands one chunk to each thread. Each thread then runs its own chain for its block of rows and writes directly into the shared matrix; the threads never touch the same elements, so no locking or copying between them is needed. (The original distributed versions were one-liners: in Julia a function consisting of a single expression can be written by giving just the signature and that expression, with the chunk constructor supplied as an anonymous function using the right-pointing arrow ->.)
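
A quick way to confirm that the threads are available and that the sampler runs (no output is reproduced here; the Statistics import is only for the column means):

```julia
# start Julia with `julia -t 4` before running this
Threads.nthreads()             # should report 4
sim = JGibbs5(20000, 200)
size(sim)                      # (20000, 2)

using Statistics
mean(sim, dims = 1)            # column means of the sampled (x, y) pairs
```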

The timings reported in the original post for the distributed versions,

```julia
julia> sum([@elapsed dJGibbs3a(20000, 200) for i=1:10])
1.6914057731628418
julia> sum([@elapsed dJGibbs3a(20000, 200) for i=1:10])
1.6724529266357422
julia> sum([@elapsed dJGibbs3b(20000, 200) for i=1:10])
2.2329299449920654
julia> sum([@elapsed dJGibbs3b(20000, 200) for i=1:10])
2.267037868499756
```

are remarkable. The speed-up with 4 processes leaving the results as a distributed array, which would be the recommended approach if we were going to do further processing, is essentially 4x. This is because there is almost no communication overhead. When converting the results to a (non-distributed) array, the speed-up is 3x.
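
Timings for the threaded JGibbs5 are not reproduced in this fork; they can be collected in the same style as the earlier samplers, or with the BenchmarkTools package loaded above, for example:

```julia
# ten runs, matching the style used for the earlier timings
sum([@elapsed JGibbs5(20000, 200) for i in 1:10])

# or a more careful benchmark with BenchmarkTools
using BenchmarkTools
@btime JGibbs5(20000, 200);
```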

If you haven't looked into Julia before now, you owe it to yourself to do so.
