Created
September 23, 2020 14:01
-
-
Save TheSalocin/a52f493c2d2a90d7b9b73b4caf47f22e to your computer and use it in GitHub Desktop.
Simple Brian2 SNPE adaptation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import brian2 as br
import matplotlib.pyplot as plt
import numpy as np

# `%matplotlib inline` is IPython syntax and a SyntaxError in a plain .py file;
# apply the magic only when actually running under IPython/Jupyter.
try:
    get_ipython().run_line_magic("matplotlib", "inline")  # noqa: F821
except NameError:
    pass  # plain Python interpreter: keep matplotlib's default backend

# Simulate through Brian2's C++ standalone device. run_network calls
# device.reinit()/device.activate() after each run so it can be reused.
br.set_device("cpp_standalone")

# Neuron model: v relaxes exponentially towards 25 with time constant tau and
# is not integrated while refractory. `tau` is not defined here: Brian2
# resolves it from the namespace in which the network is run (run_network
# defines it as a local variable).
eqs = '''
dv/dt = (25-v)/tau : 1 (unless refractory)
'''
# Reset applied after each spike; `Vr` is likewise resolved at run time.
rst = """
v=Vr
"""
def run_network(params, N, eqs, rst, time=20000*br.ms):
    """
    Build and run the network with the specified parameters.

    args:
        params: np.array([tref, tau]); both values are interpreted in ms
        N: number of neurons
        eqs: equations governing neuronal dynamics (must use `tau`)
        rst: what to do after a spike (may use `Vr`)
        time: the time for which to run the network (a Brian2 time quantity,
            e.g. 20000*br.ms)
    returns:
        histogram (counts) of all spike times, with round(10 * time/ms) bins
        covering [0, time] -- i.e. 0.1 ms per bin -- so that every call with
        the same `time` returns an array of the same length and bin edges
    """
    Vthr = 20  # spiking threshold
    Vr = 10  # reset after spike
    tref = params[0]*br.ms  # refractory period
    tau = params[1]*br.ms  # membrane time constant
    # Vthr, Vr and tau are not passed to Brian2 explicitly: they are looked up
    # in the namespace of the frame that calls net.run(), i.e. these locals.
    G = br.NeuronGroup(N, model=eqs, threshold="v>Vthr", reset=rst,
                       refractory=tref, method="euler")
    Spikes = br.SpikeMonitor(G)
    net = br.Network([G, Spikes])
    net.run(time)
    print("done with a run")
    # Copy the spike times out as a plain float array (in ms) before the
    # standalone device is reset for the next run.
    spiketimes = np.asarray(Spikes.t/br.ms)
    br.device.reinit()
    br.device.activate()
    # Has to be a histogram because the number of spikes varies between runs
    # (raw spike-time arrays would not broadcast to a common shape).
    # The explicit range pins the bin edges to [0, duration]; without it numpy
    # spans only [first spike, last spike], so bins would differ between runs.
    duration_ms = float(time/br.ms)
    n_bins = int(round(10*duration_ms))  # round, not truncate, the float count
    return np.histogram(spiketimes, bins=n_bins, range=(0, duration_ms))[0]
#define the simulator
from delfi.simulator.BaseSimulator import BaseSimulator


class NetworkSim(BaseSimulator):
    """delfi simulator that wraps `run_network`.

    Each call to `gen_single` runs one Brian2 simulation with the parameter
    vector [tref, tau] (both in ms) and returns {"data": spike histogram}.
    """

    def __init__(self, N, eqs, rst, time=20000*br.ms, seed=None):
        """
        args:
            N: number of neurons
            eqs: Brian2 model equations
            rst: Brian2 reset statement
            time: simulated duration per run (Brian2 time quantity)
            seed: seed forwarded to BaseSimulator
        """
        # The model has two free parameters, [tref, tau]; the previous value
        # of 1 did not match the parameter vector passed to gen_single.
        dim_param = 2
        super().__init__(dim_param=dim_param, seed=seed)
        self.N = N
        self.eqs = eqs
        self.rst = rst
        self.time = time
        # stored as an attribute so a different runner can be swapped in
        self.run_network = run_network

    def gen_single(self, params):
        """Run one simulation for a 1-D parameter vector [tref, tau]."""
        assert params.ndim == 1, "parameter dimension must be 1"
        states = self.run_network(params, self.N, self.eqs, self.rst, self.time)
        return {"data": states}
#create prior distributions, simulator/generator and ground-truth data
import delfi.distribution as dd
import delfi.generator as dg
from delfi.summarystats.Identity import Identity

# Uniform prior over [tref, tau] (both in ms).
seed_p = 2
prior_min, prior_max = np.array([0, 0.1]), np.array([20, 100])
prior = dd.Uniform(lower=prior_min, upper=prior_max, seed=seed_p)

# Simulator for a single neuron, 40000 ms per run.
m = NetworkSim(N=1, eqs=eqs, rst=rst, time=40000*br.ms)
# NOTE(review): with Identity the summary statistic is the raw histogram,
# which run_network makes 10 bins per ms long (400000 values for 40000 ms) --
# confirm this input size is intended for the density estimator.
s = Identity()
foo = dg.Default(model=m, prior=prior, summary=s)

# Ground-truth simulation that plays the role of the observed data.
labels_params = ["tref", "tau"]
true_params = np.array([2, 20])
obs = m.gen_single(true_params)
obs_stats = s.calc([obs])
#meta-parameters for SNPE
import delfi.inference as infer

seed_inf = 1
pilot_samples = 10
# training schedule
n_train, n_rounds = 10, 2
# fitting setup
minibatch, epochs = 5, 10
# NOTE(review): with n_train=10, val_frac=0.05 corresponds to 0.5 samples per
# round -- confirm delfi does not end up with an empty validation set.
val_frac = 0.05
# network setup
n_hiddens = [50, 50]
# convenience: let delfi normalise parameters based on the prior
prior_norm = True
# MAF density estimator
density = 'maf'
n_mades = 5  # number of MADES

# inference object
snpe_options = dict(
    obs=obs_stats,
    n_hiddens=n_hiddens,
    seed=seed_inf,
    pilot_samples=pilot_samples,
    n_mades=n_mades,
    prior_norm=prior_norm,
    density=density,
)
res = infer.SNPEC(foo, **snpe_options)

# train; run() presumably returns per-round logs, training data and posteriors
training_options = dict(
    n_train=n_train,
    n_rounds=n_rounds,
    minibatch=minibatch,
    epochs=epochs,
    silent_fail=False,
    proposal='prior',
    val_frac=val_frac,
    verbose=True,
)
loglik, _, posterior = res.run(**training_options)
#plot the loss of every training round
# loglik holds one training log per round; indexing only loglik[0] silently
# dropped all rounds after the first, so plot the rounds back to back.
round_losses = [np.asarray(log['loss']) for log in loglik]
fig = plt.figure(figsize=(15, 5))
plt.plot(np.concatenate(round_losses), lw=2)
# dashed vertical lines mark where each subsequent round starts
for round_start in np.cumsum([len(r) for r in round_losses])[:-1]:
    plt.axvline(round_start, color='gray', ls='--', lw=1)
plt.xlabel('iteration')
plt.ylabel('loss');
#plot posterior distribution
from delfi.utils.viz import samples_nd

# Plot limits/ticks: one [lower, upper] row per parameter, taken from the prior.
prior_min = foo.prior.lower
prior_max = foo.prior.upper
prior_lims = np.concatenate((prior_min.reshape(-1, 1), prior_max.reshape(-1, 1)), axis=1)
# `posterior` holds one estimate per training round; with n_rounds > 1 the
# last entry is the final posterior (index 0 is only the first-round estimate).
posterior_samples = posterior[-1].gen(1000)
###################
# colors
def hex2rgb(h):
    """Convert a 6-digit hex string (e.g. '30C05D') to an (R, G, B) tuple of ints in [0, 255]."""
    return tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))


# RGB colors in [0, 255], keyed by plot element
col = {
    'GT': hex2rgb('30C05D'),
    'SNPE': hex2rgb('2E7FE8'),
    'SAMPLE1': hex2rgb('8D62BC'),
    'SAMPLE2': hex2rgb('AF99EF'),
}
# convert to RGB colors in [0, 1]
col = {name: tuple(channel / 255 for channel in rgb) for name, rgb in col.items()}
###################
# posterior: pairwise plot of the posterior samples (KDE on the diagonal and
# upper panels) with the ground-truth parameters overlaid as points
plot_kwargs = {
    'limits': prior_lims,
    'ticks': prior_lims,
    'labels': labels_params,
    'diag': 'kde',
    'upper': 'kde',
    'hist_diag': {'bins': 50},
    'hist_offdiag': {'bins': 50},
    'kde_diag': {'bins': 50, 'color': col['SNPE']},
    'kde_offdiag': {'bins': 50},
    'points': [true_params],
    'points_offdiag': {'markersize': 5},
    'points_colors': [col['GT']],
    'title': '',
}
fig, axes = samples_nd(posterior_samples, **plot_kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment