Skip to content

Instantly share code, notes, and snippets.

@TheSalocin
Created September 23, 2020 14:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TheSalocin/a52f493c2d2a90d7b9b73b4caf47f22e to your computer and use it in GitHub Desktop.
Save TheSalocin/a52f493c2d2a90d7b9b73b4caf47f22e to your computer and use it in GitHub Desktop.
Simple Brian2 SNPE adaptation
import brian2 as br
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
# Every simulation goes through Brian2's C++ standalone device; run_network()
# calls br.device.reinit()/br.device.activate() after each run so the next
# call can build a fresh project.
br.set_device("cpp_standalone")

# Membrane equation: v relaxes towards 25 with time constant `tau` and is not
# integrated while the neuron is refractory. `tau` is deliberately not defined
# here -- it is resolved from the namespace in which the network is run
# (run_network defines it as a local before calling net.run).
eqs = "\ndv/dt = (25-v)/tau : 1 (unless refractory)\n"

# Reset statement applied after each spike; `Vr` is resolved the same way.
rst = "\nv=Vr\n"
def run_network(params, N, eqs, rst, time=20000*br.ms, bins_per_ms=10):
    """Build and run a Brian2 network and return a binned spike train.

    Args:
        params: np.array([tref, tau]); refractory period and membrane time
            constant, both in ms.
        N: number of neurons.
        eqs: equations governing the neuronal dynamics.
        rst: reset statement executed after a spike.
        time: duration of the run (a Brian2 time quantity).
        bins_per_ms: number of histogram bins per millisecond of simulated
            time (default 10, i.e. 0.1 ms bins).

    Returns:
        Array of spike counts per time bin. The bins always cover
        [0, time], so every call returns a vector of the same length whose
        entries describe the same time windows, independent of when (or
        whether) spikes occurred.
    """
    # These locals are picked up by Brian2 from this frame when net.run() is
    # called: Vthr/Vr by the threshold/reset strings, tau by the equations.
    Vthr = 20  # spiking threshold
    Vr = 10  # reset value after a spike
    tref = params[0] * br.ms  # refractory period
    tau = params[1] * br.ms  # membrane time constant

    group = br.NeuronGroup(N, model=eqs, threshold="v>Vthr", reset=rst,
                           refractory=tref, method="euler")
    spike_mon = br.SpikeMonitor(group)
    net = br.Network([group, spike_mon])
    net.run(time)
    print("done with a run")
    # Copy the spike times out before the device is reset.
    spiketimes = np.asarray(spike_mon.t / br.ms)
    # Reset the standalone device so the next call can build a new project.
    br.device.reinit()
    br.device.activate()

    duration_ms = float(time / br.ms)
    n_bins = int(round(bins_per_ms * duration_ms))
    # A histogram is returned because the number of spikes varies between
    # runs (variable-length arrays cannot be broadcast together). The explicit
    # range is essential: without it np.histogram spreads the bins between the
    # first and last spike, so bins would not line up across runs.
    counts, _ = np.histogram(spiketimes, bins=n_bins, range=(0.0, duration_ms))
    return counts
#define the simulator
from delfi.simulator.BaseSimulator import BaseSimulator
class NetworkSim(BaseSimulator):
    """delfi simulator that runs `run_network` for one parameter vector.

    The parameter vector is [tref, tau] (both in ms), as consumed by
    run_network.

    Args:
        N: number of neurons.
        eqs: Brian2 equations governing the neuronal dynamics.
        rst: reset statement executed after a spike.
        time: simulated duration of each run (Brian2 time quantity).
        seed: forwarded unchanged to BaseSimulator.
    """

    def __init__(self, N, eqs, rst, time=20000*br.ms, seed=None):
        # run_network reads params[0] (tref) and params[1] (tau), so the
        # parameter space is two-dimensional.
        super().__init__(dim_param=2, seed=seed)
        self.N = N
        self.eqs = eqs
        self.rst = rst
        self.time = time
        # Stored on the instance so another runner with the same call
        # signature can be substituted.
        self.run_network = run_network

    def gen_single(self, params):
        """Simulate a single parameter vector.

        Args:
            params: 1-D array-like [tref, tau].

        Returns:
            dict with key "data" holding the spike-count histogram returned
            by run_network.

        Raises:
            ValueError: if params is not one-dimensional.
        """
        params = np.asarray(params)
        # Explicit check instead of `assert`, which is stripped under -O.
        if params.ndim != 1:
            raise ValueError("parameter dimension must be 1")
        # NOTE(review): the simulator seed is not forwarded to Brian2;
        # run_network takes no seed argument.
        states = self.run_network(params, self.N, self.eqs, self.rst, self.time)
        return {"data": states}
# ---- prior over [tref, tau], simulator and data generator ------------------
import delfi.distribution as dd
import delfi.generator as dg
from delfi.summarystats.Identity import Identity

# Uniform box prior. Bounds are in ms, because run_network multiplies both
# parameters by br.ms: tref in [0, 20], tau in [0.1, 100].
seed_p = 2
prior_min = np.array([0, 0.1])
prior_max = np.array([20, 100])
prior = dd.Uniform(lower=prior_min, upper=prior_max, seed=seed_p)

# Single-neuron simulator, 40 s of simulated time per run.
m = NetworkSim(N=1, eqs=eqs, rst=rst, time=40000 * br.ms)
# Identity summary statistic (presumably passes the simulator output through
# unchanged -- confirm against delfi).
s = Identity()
foo = dg.Default(model=m, prior=prior, summary=s)
# Ground truth: one simulation at known parameters (tref = 2 ms, tau = 20 ms)
# whose summary statistics serve as the observation for inference.
labels_params = ["tref", "tau"]
true_params = np.array([2, 20])

obs = m.gen_single(true_params)
obs_stats = s.calc([obs])
import delfi.inference as infer

# ---- SNPE meta-parameters (passed to infer.SNPEC / res.run) -----------------
seed_inf = 1
pilot_samples = 10

# Training schedule.
n_train = 10
n_rounds = 2

# Fitting setup.
minibatch = 5
epochs = 10
# NOTE(review): with n_train = 10, val_frac = 0.05 corresponds to half a
# sample per round -- confirm how delfi rounds this before trusting the
# validation loss.
val_frac = 0.05

# Density-estimator network: two hidden layers of 50 units each.
n_hiddens = [50, 50]
# Convenience flag.
prior_norm = True
# MAF density estimator built from n_mades MADEs.
density = 'maf'
n_mades = 5

# Inference object conditioned on the ground-truth summary statistics.
res = infer.SNPEC(
    foo,
    obs=obs_stats,
    seed=seed_inf,
    pilot_samples=pilot_samples,
    prior_norm=prior_norm,
    density=density,
    n_mades=n_mades,
    n_hiddens=n_hiddens,
)
# Train for n_rounds rounds. res.run returns three values: the training logs
# (indexed per round when plotting), an unused second value, and the
# posterior(s).
loglik, _, posterior = res.run(
    n_rounds=n_rounds,
    n_train=n_train,
    epochs=epochs,
    minibatch=minibatch,
    val_frac=val_frac,
    proposal='prior',
    silent_fail=False,
    verbose=True,
)
# Plot the training loss of every round. Assumes loglik is a list holding one
# log dict (with a 'loss' entry) per round, as returned by SNPEC.run; with
# n_rounds = 2, plotting only loglik[0] would silently drop the second round.
fig = plt.figure(figsize=(15, 5))
for round_idx, round_log in enumerate(loglik):
    plt.plot(round_log['loss'], lw=2, label='round {}'.format(round_idx + 1))
plt.legend()
plt.xlabel('iteration')
# Trailing ';' suppresses the returned Text repr in a Jupyter cell.
plt.ylabel('loss');
# Plot the posterior distribution obtained after the final round.
from delfi.utils.viz import samples_nd

prior_min = foo.prior.lower
prior_max = foo.prior.upper
# One (lower, upper) row per parameter, used for both axis limits and ticks.
prior_lims = np.column_stack((prior_min, prior_max))
# `posterior` holds one distribution per round; the last entry is the final
# estimate. posterior[0] would only reflect the first of n_rounds rounds.
posterior_samples = posterior[-1].gen(1000)

###################
# colors


def hex2rgb(h):
    """Convert a 6-digit hex string such as '30C05D' to an (R, G, B) tuple in [0, 255]."""
    return tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))


# Colors as RGB tuples scaled to [0, 1].
col = {
    name: tuple(channel / 255 for channel in hex2rgb(code))
    for name, code in (
        ('GT', '30C05D'),
        ('SNPE', '2E7FE8'),
        ('SAMPLE1', '8D62BC'),
        ('SAMPLE2', 'AF99EF'),
    )
}

###################
# posterior
fig, axes = samples_nd(posterior_samples,
                       limits=prior_lims,
                       ticks=prior_lims,
                       labels=labels_params,
                       diag='kde',
                       upper='kde',
                       hist_diag={'bins': 50},
                       hist_offdiag={'bins': 50},
                       kde_diag={'bins': 50, 'color': col['SNPE']},
                       kde_offdiag={'bins': 50},
                       points=[true_params],
                       points_offdiag={'markersize': 5},
                       points_colors=[col['GT']],
                       title='')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment