Created
January 6, 2017 03:05
-
-
Save anonymous/e12b0537d33bdf25cc890c266df22b31 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# LIBRARY IMPORT | |
from numpy import * | |
from random import * | |
import pylab as pylab | |
from math import * | |
#
# BRIDGING DENSITY starting from the flat density
#
def density_bridge(x,y,center_coord, std, k,N):
    """Return the k-th bridging density density(x, y)**(k / N).

    k = 0 gives the flat density 1.0 and k = N gives `density` itself;
    intermediate k interpolate geometrically between the two.
    """
    if k == 0:
        return 1.0
    # float() so that k/N is not truncated to 0 by integer division
    # (Python 2 ints / numpy ints), which would make every bridge flat.
    return pow(density(x, y, center_coord, std), float(k) / N)
#
# mixture of Gaussian: center_coord = centres of the Gaussians
#
def density(x,y, center_coord, std):
    """Return the unnormalised Gaussian-mixture density at (x, y).

    Sum over the centres (c_x, c_y) of
    exp(-((x - c_x)**2 + (y - c_y)**2) / (2 * std**2)):
    equal-weight isotropic components with no normalising constant.
    Returns 0 when center_coord is empty.
    """
    two_var = 2 * std * std  # loop-invariant denominator
    return sum([exp(-((x - c_x) ** 2 + (y - c_y) ** 2) / two_var)
                for c_x, c_y in center_coord])
#
# 2-dimensional random walk proposals
#
def proposal(x,y,delta):
    """Gaussian random-walk proposal around the point (x, y).

    Each coordinate is perturbed independently with standard deviation
    `delta`; the x draw is taken before the y draw. Returns a 2-tuple.
    """
    return tuple(gauss(mu, delta) for mu in (x, y))
#
# compute Effective Sampling Size (ESS)
#
def compute_ESS(weights):
    """Return the Kish effective sample size (sum w)**2 / sum(w**2).

    For normalised weights (sum w == 1) this is 1 / sum(w**2). Unnormalised
    weights are handled correctly too. Returns 0.0 for an empty or all-zero
    weight list instead of dividing by zero.
    """
    total = float(sum(weights))
    sum_sq = sum([w * w for w in weights])
    if sum_sq == 0:
        return 0.0
    return total * total / sum_sq
def resampling(particles, weights):
    """Systematic resampling: return len(particles) particles chosen by weight.

    One uniform offset U is drawn, and the N equally spaced points
    (U + k) * total / N, k = 0..N-1, are located in the cumulative weights.
    The weights are expected to be normalised. If they are not, a warning is
    printed and the thresholds are scaled by the actual total, so the draw
    still follows the weights. Returns a new list; the inputs are not modified.
    """
    N = len(particles)
    if N == 0:
        return []
    total = sum(weights)
    # check that the weights are normalized
    if (total < 0.999) or (total > 1.001):
        print("ERROR: !! weights are not normalized !!")
    step = total / float(N)
    particles_new = []
    current_index = 0
    w_sum = weights[0]
    current = random() * step
    for k in range(N):
        # Advance to the particle whose cumulative weight reaches `current`.
        # The index bound guards against floating-point round-off pushing
        # `current` past the final cumulative sum.
        while current > w_sum and current_index < N - 1:
            current_index += 1
            w_sum += weights[current_index]
        particles_new.append(particles[current_index])
        current += step
    return particles_new
#
# SMC sampler
#
def _log_density(x, y, center_coord, std):
    """Return log(density(x, y, center_coord, std)) without underflow.

    Same mixture as `density`, evaluated with the log-sum-exp trick, so that
    points far from every centre give a large negative number instead of
    log(0). Returns -inf when center_coord is empty.
    """
    two_var = 2.0 * std * std
    exponents = [-((x - c_x) ** 2 + (y - c_y) ** 2) / two_var
                 for c_x, c_y in center_coord]
    if not exponents:
        return float("-inf")
    top = max(exponents)
    return top + log(sum([exp(e - top) for e in exponents]))


def _save_smc_frame(particles, center_coord, index):
    """Save a scatter plot of the particles (red) and centres (blue) via pylab."""
    name = "smc_sampler_" + str(1000 + index)
    pylab.clf()
    radius = 30
    frame1 = pylab.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)
    pylab.axis([-radius, radius, -radius, radius])
    particles_x = [x for (x, y) in particles]
    particles_y = [y for (x, y) in particles]
    pylab.plot(particles_x, particles_y, "ro", alpha=0.3)
    center_x = [x for (x, y) in center_coord]
    center_y = [y for (x, y) in center_coord]
    pylab.plot(center_x, center_y, "bo")
    pylab.savefig(name)


def smc(particles, delta, N_temperature, center_coord, std, save_file = False):
    """Run a tempered SMC sampler and return an equally weighted particle set.

    The particles are moved through the bridge targets
    pi_t(x, y) proportional to density(x, y)**(t / N_temperature),
    t = 1, ..., N_temperature (the family computed by density_bridge). The
    last target is the Gaussian mixture `density` itself. Each iteration:
      1. reweights every particle by pi_t / pi_{t-1}
         = density**(1 / N_temperature), at its current position;
      2. applies one random-walk Metropolis-Hastings step (proposal with
         scale `delta`) that leaves pi_t invariant;
      3. renormalises the weights and resamples when ESS < N_particles / 2;
      4. if save_file, saves a plot named "smc_sampler_<1000 + t>".

    particles: list of (x, y) tuples. It is copied, not modified in place.
    Returns the particle list after a final resampling step.
    """
    particles = list(particles)  # do not mutate the caller's list
    N_particles = len(particles)
    if N_particles == 0:
        return []
    # Normalised from the start, so the final resampling is valid even if
    # N_temperature == 0.
    weights = [1.0 / N_particles] * N_particles
    for t in range(1, N_temperature + 1):
        print("iteration  %s  out of  %s" % (t, N_temperature))
        beta = float(t) / N_temperature  # tempering exponent of pi_t
        for k in range(N_particles):
            x, y = particles[k]
            log_d = _log_density(x, y, center_coord, std)
            # incremental weight pi_t / pi_{t-1}, at the pre-move position
            weights[k] = weights[k] * exp(log_d / N_temperature)
            # Metropolis-Hastings move targeting the tempered density pi_t
            # (not the full density), computed in log space for stability.
            x_new, y_new = proposal(x, y, delta)
            log_ratio = beta * (_log_density(x_new, y_new, center_coord, std) - log_d)
            if log_ratio >= 0 or random() < exp(log_ratio):
                particles[k] = (x_new, y_new)
        # renormalise weights
        w_sum = sum(weights)
        weights = [w / w_sum for w in weights]
        ESS = compute_ESS(weights)
        print(" ESS= %s" % ESS)
        # resample if ESS < N/2
        if ESS < N_particles / 2.0:
            particles = resampling(particles, weights)
            # Uniform *normalised* weights. Resetting to 1.0 would make the
            # final resampling below collapse onto the first particle
            # whenever resampling fires on the last iteration.
            weights = [1.0 / N_particles] * N_particles
        if save_file:
            _save_smc_frame(particles, center_coord, t)
    return resampling(particles, weights)
#
# CREATE TARGET + BRIDGE DENSITIES
#
# 7 Gaussian centres evenly spaced on a circle of radius `radius`.
# Integer-indexed angles avoid both the 3.1415 approximation of pi and the
# float end point of arange(0, 2*pi, 2*pi/7), which may yield an 8th centre.
center_coord = []; std = 1.0; radius = 20
n_centers = 7
for theta in [2 * pi * i / n_centers for i in range(n_centers)]:
    center_coord.append( (radius*cos(theta), radius*sin(theta)) )
#
# CREATE PARTICLES SYSTEM
#
N_particles = 3000; N_temperature = 100
particles = [(0.0, 0.0) for k in range(N_particles)]
#
# EVOLVE PARTICLES
#
# save_file=True makes smc save one figure per iteration (smc_sampler_1001, ...).
delta = 0.6; save_file = True
particles = smc(particles, delta, N_temperature, center_coord, std, save_file)
particles_x = [x for (x,y) in particles]
particles_y = [y for (x,y) in particles]
#
# PLOT THE RESULTS
#
pylab.figure()
pylab.plot(particles_x, particles_y, "ro",alpha=0.7)
pylab.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment