@HGangloff
HGangloff / equivalence_gradEM_gradLlkh.py
Created September 25, 2021 07:02
Equivalence between a gradient ascent over the Expectation Maximization quantity Q and a gradient ascent over the model likelihood in the case of training a Hidden Markov Chain with Gaussian Independent Noise
'''
Equivalence between a gradient ascent over the Expectation Maximization
quantity Q and a gradient ascent over the model likelihood in the case of
training a Hidden Markov Chain with Gaussian Independent Noise
'''
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
import jax
from jax.scipy.stats import norm
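The equivalence named in the description is commonly known as Fisher's identity. The short derivation below is a sketch and is not taken from the gist itself. The notation is an assumption: h is the hidden chain, x the observations, theta the parameters being differentiated, and theta' the parameters at which the posterior is evaluated.

```latex
\begin{aligned}
Q(\theta, \theta') &= \mathbb{E}_{p_{\theta'}(h \mid x)}\left[\log p_{\theta}(h, x)\right], \\
\log p_{\theta}(x) &= Q(\theta, \theta')
  - \mathbb{E}_{p_{\theta'}(h \mid x)}\left[\log p_{\theta}(h \mid x)\right], \\
\nabla_{\theta}\, \mathbb{E}_{p_{\theta'}(h \mid x)}\left[\log p_{\theta}(h \mid x)\right]\Big|_{\theta = \theta'}
  &= \sum_{h} \nabla_{\theta}\, p_{\theta}(h \mid x)\Big|_{\theta = \theta'}
   = \nabla_{\theta}\, 1 = 0, \\
\nabla_{\theta} \log p_{\theta}(x)\Big|_{\theta = \theta'}
  &= \nabla_{\theta}\, Q(\theta, \theta')\Big|_{\theta = \theta'}.
\end{aligned}
```

So a gradient step on Q, with the posterior recomputed at the current parameters, moves in the same direction as a gradient step on the log-likelihood.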
HGangloff / gmrf_simulation.py
Last active September 19, 2021 09:51
Gaussian Markov Random Field simulation using Fourier properties and base matrices for efficiency
'''
Gaussian Markov Random Field simulation using Fourier properties and base matrices for efficiency
References from Gaussian Markov Random Fields: Theory and Applications, Håvard Rue and Leonhard Held
We treat the 2D case, with 0 mean, stationary variance and exponential correlation function
'''
import matplotlib.pyplot as plt
import numpy as np
from scipy.fftpack import fft2, ifft2
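The FFT-based idea from the description can be sketched as follows. This is a minimal illustration, not the gist's code. It assumes the field lives on a torus, so the covariance matrix is block-circulant and fully described by an lx-by-ly base matrix. The function names, the range parameter `r`, and the clipping of negative eigenvalues are all assumptions made for this example, and it uses `numpy.fft` rather than `scipy.fftpack`.

```python
import numpy as np


def torus_exponential_base(lx, ly, sigma2=1.0, r=5.0):
    """Base matrix of a block-circulant covariance with exponential correlation.

    Entry (i, j) is the covariance between site (0, 0) and site (i, j),
    with distances measured on the torus (hypothetical helper).
    """
    dx = np.minimum(np.arange(lx), lx - np.arange(lx))
    dy = np.minimum(np.arange(ly), ly - np.arange(ly))
    dist = np.sqrt(dx[:, None] ** 2 + dy[None, :] ** 2)
    return sigma2 * np.exp(-dist / r)


def sample_from_base(base, rng):
    """Draw one zero-mean field whose covariance has the given base matrix."""
    n_sites = base.size
    # The 2D DFT of the base gives the eigenvalues of the full covariance.
    lam = np.fft.fft2(base).real
    # Assumption: clip small negative eigenvalues instead of failing.
    lam = np.clip(lam, 0.0, None)
    # Complex white noise with independent real and imaginary parts.
    z = rng.standard_normal(base.shape) + 1j * rng.standard_normal(base.shape)
    # Colour the noise in the Fourier domain and keep the real part.
    return np.fft.fft2(np.sqrt(lam) * z).real / np.sqrt(n_sites)


base = torus_exponential_base(32, 32)
field = sample_from_base(base, np.random.default_rng(0))
```

Only the base matrix is ever stored, so memory and time scale with the number of sites rather than with its square.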
HGangloff / forward_backward_hmcin_ctypes.py
Created August 17, 2021 16:53
Simple and extremely fast implementation of rescaled Forward Backward algorithm for Hidden Markov Chains with Independent Gaussian Noise in Python calling C using ctypes and creating a shared library. To compile just type 'make all' at the root of the directory
import numpy as np
from scipy.stats import norm
from ctypes import CDLL, c_int, c_double, c_char, c_void_p
def generate_observations(H, means, stds):
    X = ((H == 0) * (means[0] + np.random.randn(*H.shape) * stds[0]) +
         (H == 1) * (means[1] + np.random.randn(*H.shape) * stds[1]))
    return X
HGangloff / SIR_particle_filter_jax.py
Last active August 20, 2021 10:16
Efficient Sequential Importance Resampling Particle Filter with Jax using jit and lax.scan
'''
Sequential Importance Resampling Particle Filter with Jax using jit and
lax.scan
The model equations for this simple application are:
X_{n} = 0.5 * X_{n-1} + 25 * X_{n-1} / (1 + X_{n-1}^2) + 8 * cos(1.2 * n) + U
Y_{n} = X_{n}^2 / 20 + V
where U~N(0, 10) and V~N(0, 1)
'''
import matplotlib.pyplot as plt
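For orientation, a plain NumPy sketch of a Sequential Importance Resampling filter on the model stated in the docstring might look like the following. It is not the gist's JAX implementation and uses a Python loop instead of `jit` and `lax.scan`. Several things are assumptions: the second argument of N(0, 10) and N(0, 1) is read as a variance, the initial state is drawn from N(0, 10), the proposal is the transition prior (bootstrap filter), and resampling is multinomial at every step. All function names are hypothetical.

```python
import numpy as np


def transition_mean(x, n):
    """Deterministic part of the state equation at time index n."""
    return 0.5 * x + 25.0 * x / (1.0 + x ** 2) + 8.0 * np.cos(1.2 * n)


def simulate(n_steps, rng, var_u=10.0, var_v=1.0):
    """Simulate hidden states and observations from the model."""
    x = np.zeros(n_steps)
    y = np.zeros(n_steps)
    x_prev = rng.normal(0.0, np.sqrt(var_u))  # assumed initial distribution
    for n in range(1, n_steps + 1):
        x[n - 1] = transition_mean(x_prev, n) + rng.normal(0.0, np.sqrt(var_u))
        y[n - 1] = x[n - 1] ** 2 / 20.0 + rng.normal(0.0, np.sqrt(var_v))
        x_prev = x[n - 1]
    return x, y


def sir_filter(y, n_particles, rng, var_u=10.0, var_v=1.0):
    """Return the filtered posterior mean of X_n for every time step."""
    particles = rng.normal(0.0, np.sqrt(var_u), size=n_particles)
    estimates = np.zeros(len(y))
    for n, y_n in enumerate(y, start=1):
        # Propagate every particle through the state equation.
        noise = rng.normal(0.0, np.sqrt(var_u), size=n_particles)
        particles = transition_mean(particles, n) + noise
        # Weight by the observation likelihood, computed in log space.
        log_w = -0.5 * (y_n - particles ** 2 / 20.0) ** 2 / var_v
        w = np.exp(log_w - log_w.max())
        w /= w.sum()
        estimates[n - 1] = np.sum(w * particles)
        # Multinomial resampling resets the weights to uniform.
        particles = particles[rng.choice(n_particles, size=n_particles, p=w)]
    return estimates


rng = np.random.default_rng(0)
x_true, y_obs = simulate(50, rng)
x_est = sir_filter(y_obs, 500, rng)
```

Because the observation depends on the square of the state, the filtering distribution is often bimodal, so the posterior mean can sit between the two modes.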
HGangloff / chromatic_gibbs_sampler_jax.py
Created August 3, 2021 09:50
Efficient chromatic Gibbs sampler for a binary Ising Markov random field with Jax jit, vmap and lax.scan
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
from jax import vmap, jit
from jax.scipy.signal import convolve2d
def color_image_graph(lx, ly, neigh_list):
def first_available(color_list):
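The preview above is truncated, so the following is an independent sketch of the chromatic idea rather than the gist's code. It assumes a 4-neighbour Ising field with states in {-1, +1}, free boundaries, and a fixed checkerboard 2-colouring instead of a general graph-colouring routine. Sites of one colour share no edges, so they are conditionally independent given the other colour and can be updated in one vectorised step. The names `neighbour_sum`, `chromatic_gibbs`, and `beta` are hypothetical.

```python
import numpy as np


def neighbour_sum(x):
    """Sum of the 4 nearest neighbours, with zeros outside the image."""
    p = np.pad(x, 1)
    return p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:]


def chromatic_gibbs(lx, ly, beta, n_sweeps, rng):
    """Run n_sweeps of checkerboard Gibbs sampling and return the final field."""
    x = rng.choice(np.array([-1, 1]), size=(lx, ly))
    ii, jj = np.indices((lx, ly))
    # Two colour classes: no two sites of the same colour are neighbours.
    colours = [(ii + jj) % 2 == c for c in (0, 1)]
    for _ in range(n_sweeps):
        for mask in colours:
            s = neighbour_sum(x)
            # Conditional probability of +1 given the neighbours.
            p_plus = 1.0 / (1.0 + np.exp(-2.0 * beta * s))
            draw = np.where(rng.random((lx, ly)) < p_plus, 1, -1)
            # Only sites of the current colour take their new value.
            x = np.where(mask, draw, x)
    return x


sample = chromatic_gibbs(16, 16, beta=0.6, n_sweeps=20, rng=np.random.default_rng(0))
```

A larger neighbourhood, such as 8 neighbours, needs more than two colours, which is the case a general colouring function would handle.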
HGangloff / forward_backward_hmcin_jax.py
Last active August 18, 2021 09:56
Simple implementation of rescaled Forward Backward algorithm for Hidden Markov Chains with Independent Gaussian Noise in Jax using the lax.scan function
import numpy as np
import jax.numpy as jnp
import jax
from jax.scipy.stats import norm
def generate_observations(H, means, stds):
    X = ((H == 0) * (means[0] + np.random.randn(*H.shape) * stds[0]) +
         (H == 1) * (means[1] + np.random.randn(*H.shape) * stds[1]))
    return X
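Since the preview stops before the algorithm itself, here is a minimal NumPy sketch of a rescaled Forward Backward pass for a chain with Gaussian independent noise. It is not the gist's JAX code and uses Python loops where the gist uses `lax.scan`. The function names, the argument layout (`pi0` for the initial distribution, `A` for the transition matrix), and the example parameter values are assumptions.

```python
import numpy as np


def gaussian_pdf(x, mean, std):
    """Density of N(mean, std^2) evaluated at x."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


def forward_backward(X, pi0, A, means, stds):
    """Return the posterior marginals p(H_t | X) and the log-likelihood."""
    T, K = len(X), len(pi0)
    lik = gaussian_pdf(X[:, None], means[None, :], stds[None, :])  # shape (T, K)
    alpha = np.zeros((T, K))
    beta = np.ones((T, K))
    scale = np.zeros(T)
    # Forward pass: normalise at each step to avoid numerical underflow.
    a = pi0 * lik[0]
    scale[0] = a.sum()
    alpha[0] = a / scale[0]
    for t in range(1, T):
        a = (alpha[t - 1] @ A) * lik[t]
        scale[t] = a.sum()
        alpha[t] = a / scale[t]
    # Backward pass: reuse the forward scaling constants.
    for t in range(T - 2, -1, -1):
        beta[t] = (A @ (lik[t + 1] * beta[t + 1])) / scale[t + 1]
    gamma = alpha * beta  # rows sum to 1 with this scaling
    return gamma, np.log(scale).sum()


rng = np.random.default_rng(0)
means = np.array([0.0, 1.0])
stds = np.array([0.5, 0.5])
A = np.array([[0.9, 0.1], [0.1, 0.9]])
pi0 = np.array([0.5, 0.5])
# Simulate a two-state chain, then add Gaussian noise around the state means.
H = np.zeros(200, dtype=int)
for t in range(1, 200):
    H[t] = rng.choice(2, p=A[H[t - 1]])
X = means[H] + stds[H] * rng.standard_normal(200)
gamma, loglik = forward_backward(X, pi0, A, means, stds)
```

The product of the scaling constants is the likelihood of the observations, which is why the log-likelihood comes out as the sum of their logarithms.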