Skip to content

Instantly share code, notes, and snippets.

@rseydam
Created June 10, 2022 05:47
Show Gist options
  • Save rseydam/7b141c52bbcecf0811426b6a54bf3658 to your computer and use it in GitHub Desktop.
Save rseydam/7b141c52bbcecf0811426b6a54bf3658 to your computer and use it in GitHub Desktop.
# Two-population spiking model
# two coupled populations (inhibitory <-> excitatory), coupled via FFT convolutions
# structures, stepper, plotting
###
using Setfield
using Parameters
using Random
using SparseArrays
using LinearAlgebra
using NNlib #gives a bunch of activation function also used by flux.jl
using FFTW
using JLD2
##
# Heaviside step functions on floats; they differ only in the value taken at x == 0:
#   hs(0)  == 1   (also for -0.0)
#   hs2(0) == 0   (also for -0.0)
# Both map NaN to one(x) and return the same float type as the argument.
hs(x::AbstractFloat) = x < 0 ? zero(x) : one(x)
hs2(x::AbstractFloat) = x <= 0 ? zero(x) : one(x)
"""
    connectivity(n, l, w)

Build the `n×n` coupling kernel (a `Matrix{Float64}`) in FFT-shifted, wrap-around grid
coordinates, so it can be Fourier transformed directly for circular convolutions.

Along each axis the coordinates are `fftfreq(n) * n`, i.e. `0, 1, 2, …, -2, -1`
(integer-valued up to floating-point rounding). An entry equals `w` when the Euclidean
distance of its (row, column) coordinates from the origin is at most `l`, and `0.0`
otherwise. The sign of `w` decides whether the coupling is excitatory or inhibitory.
"""
function connectivity(n, l, w)
    coords = fftfreq(n) * n              # wrap-around coordinates along one axis
    kernel = zeros(n, n)
    for col in 1:n, row in 1:n           # column-major traversal
        dist = sqrt(coords[col] * coords[col] + coords[row] * coords[row])
        if dist <= l
            # flat "top-hat" kernel; a different (e.g. graded) synaptic profile could go here
            kernel[row, col] = w
        end
    end
    return kernel
end
##
# Two-population spiking model: parameters and complete simulation state, with concretely
# typed fields (except the two FFT plans).
#
# @with_kw (Parameters.jl) provides a keyword constructor, e.g. SPnet2(nsq = 32, dt = 0.5).
# Every default below is evaluated once at construction, in field order, and may refer to
# fields declared above it. Derived fields (ns_*, w_four_*, hist_*, the history buffers,
# the FFT plans, ...) are therefore NOT updated if a parameter is changed afterwards
# (e.g. with Setfield's @set); they have to be changed together with it.
# Grid quantities are nsq×nsq matrices. Index 1 of psi/s is the inhibitory population,
# index 2 the excitatory one.
@with_kw struct SPnet2
    # --- grid and time ---
    nsq::Int64 = 64 # side length of the square 2d grid
    n::Int64 = nsq*nsq # neural sheet size: number of units per population (nsq^2)
    dt::Float64 = 1.0 # timestep size/simulation step
    t::Vector{UInt64} = [0] # step counter, a 1-element vector so it can be advanced in place; time = t[1]*dt
    # --- neurons ---
    τex::Float64 = 40.0 # membrane time constant of the excitatory population
    τin::Float64 = 20.0 # membrane time constant of the inhibitory population
    # --- coupling kernels (see connectivity): weight w within radius l, in grid units ---
    l_in::Float64 = 12.0 # range of inhibition
    w_in::Float64 = -1.0 # strength of inhibition; the negative sign makes it inhibitory
    l_ex::Float64 = 4.0 # range of excitation
    w_ex::Float64 = 1.0 # strength of excitation
    # --- synaptic delays in time units (converted to whole steps in hist_*) ---
    delay_in::Float64 = 2.0 # inhibitory -> excitatory
    delay_etoi::Float64 = 2.0 # excitatory -> inhibitory
    delay_etoe::Float64 = 5.0 # excitatory -> excitatory
    # --- external drive ---
    a_ex::Float64 = 1.1 # constant drive of the excitatory population
    a_in::Float64 = 0.8 # constant drive of the inhibitory population
    a_m::Float64 = 0.0 # amplitude of the global drive a_m*sin(ω_m*t[1]*dt)
    ω_m::Float64 = 0.0 # angular frequency of that global drive
    d_0::Float64 = 0.0 # amplitude of the spatially heterogeneous drive (disabled when 0)
    λ_B::Float64 = 0.0 # its wavelength in grid columns: d(x,t) = d_0*cos(x/λ_B *2π + ϕ_t + Θ); λ_B = 0 with d_0 != 0 divides by zero
    Θ::Float64 = 0.0 # its constant phase offset
    # --- random number generators and noise ---
    rngseed::UInt64 = rand(UInt64) # seed of the noise RNG; drawn from the global RNG by default
    rng::MersenneTwister = MersenneTwister(rngseed) # noise RNG
    rngseedIC::UInt64 = rand(UInt64) # seed of the initial-condition RNG; drawn from the global RNG by default
    rngIC::MersenneTwister = MersenneTwister(rngseedIC) # RNG used to draw the initial psi
    σ_ex::Float64 = 0.01 # noise intensity, excitatory -- ns_ex is derived from it at construction, change both together
    σ_in::Float64 = 0.01 # noise intensity, inhibitory -- ns_in is derived from it at construction, change both together
    ns_ex::Float64 = σ_ex * sqrt(dt/τex) # noise prefactor per Euler step (multiplies a standard normal draw)
    ns_in::Float64 = σ_in * sqrt(dt/τin)
    # --- Fourier-transformed coupling kernels, computed once ---
    w_four_ex::Matrix{ComplexF64} = fft( connectivity(nsq, l_ex, w_ex) )
    w_four_in::Matrix{ComplexF64} = fft( connectivity(nsq, l_in, w_in) )
    # (could become w_four[1:2], which would help when dealing with more layers)
    # note that in FFTW.jl the normalization is such that ifft(fft(x)) == x
    # --- delays in whole time steps (rounded, at least 1) = distances in the history register ---
    hist_in::Int64 = max(round(delay_in / dt), 1)
    hist_etoi::Int64 = max(round(delay_etoi / dt), 1)
    hist_etoe::Int64 = max(round(delay_etoe / dt), 1)
    # length of the history ring buffers = the largest delay in steps
    hist_size::Int64 = convert(Int,max(max(hist_in, hist_etoi), hist_etoe))
    # --- in-place FFT plans and the scratch storage they operate on ---
    s_temp::Matrix{ComplexF64} = fft(zeros(nsq,nsq)) # complex scratch matrix, transformed in place by the plans
    fw_plan! = plan_fft!( s_temp; flags=FFTW.PATIENT, timelimit=Inf) # in-place forward FFT of s_temp (untyped field)
    rv_plan! = plan_ifft!( s_temp ; flags=FFTW.PATIENT, timelimit=Inf) # in-place inverse FFT of s_temp (untyped field)
    # note: a complex in-place transform is used although the spikes are real. A real
    # transform would need only half the space, but there is no in-place rfft!; the
    # in-place complex fft! is still many times faster than recreating arrays every step.
    # --- history ring buffers of kernel-convolved spikes, one nsq×nsq slot per step ---
    sw_i::Array{Float64, 3} = zeros(nsq, nsq, hist_size) # convolved spikes of the inhibitory population
    sw_e::Array{Float64, 3} = zeros(nsq, nsq, hist_size) # convolved spikes of the excitatory population
    # (could become sw[1:2] ...)
    # history 'pointers' at initialization: [current slot, past_in, past_etoi, past_etoe]
    hist_idx::Vector{Int64} = [1,1,1,1]
    # --- buffers for the summed, delayed coupling input ---
    sw_to_ex::Matrix{Float64} = zeros(nsq,nsq) # all coupling input to the excitatory population
    sw_to_in::Matrix{Float64} = zeros(nsq,nsq) # all coupling input to the inhibitory population
    # (for more layers also sw_to[1,2] ...)
    # --- state, ordered [inhibitory, excitatory] ---
    psi::Vector{Matrix{Float64}} = [ rand(rngIC,nsq,nsq), rand(rngIC,nsq,nsq) ] # membrane potentials, initially uniform in [0, 1) drawn from rngIC
    s::Vector{Matrix{Float64}} = [ zeros(nsq,nsq), zeros(nsq,nsq) ] # spikes of the latest step (1.0 = spike, 0.0 = none)
end
##
##
# this function performs one iteration step of the spiking model
"""
    SPnet2Step!(net, ϕ_t)

Advance the two-population spiking network `net` by one time step `net.dt`, mutating its
arrays in place. `net` is an `SPnet2` or any object with the same fields. Returns `nothing`.

Population 1 (`psi[1]`, `τin`, `a_in`) is inhibitory; population 2 (`psi[2]`, `τex`, `a_ex`)
is excitatory. `ϕ_t` is the phase of the spatially heterogeneous drive
`d_0*cos(j*2π/λ_B + Θ + ϕ_t)`, which is added to grid column `j` of both populations (only
when `d_0 != 0`).

One step performs the following:
1. Read the kernel-convolved spikes emitted exactly `hist_in` (inh→exc), `hist_etoi`
   (exc→inh) and `hist_etoe` (exc→exc) steps earlier from the ring buffers `sw_i`/`sw_e`.
2. Apply an explicit Euler step to both membrane potentials: leak, constant drive `a_*`,
   global drive `a_m*sin(ω_m*t*dt)`, additive Gaussian noise scaled by `ns_*`, and the
   delayed coupling input scaled by `1/τ`.
3. Detect spikes (`psi >= 1`) into `s`, reset the spiking units to 0, and bound `psi`
   from below at -2.
4. Convolve the new spikes with the coupling kernels via FFT into the current ring-buffer
   slot, then advance `t` and the ring-buffer pointer.
"""
function SPnet2Step!(net, ϕ_t)
    psi, s, t = net.psi, net.s, net.t
    dt, τex, τin, nsq = net.dt, net.τex, net.τin, net.nsq
    a_ex, a_in, a_m, ω_m = net.a_ex, net.a_in, net.a_m, net.ω_m
    d_0, λ_B, Θ = net.d_0, net.λ_B, net.Θ
    rng, ns_ex, ns_in = net.rng, net.ns_ex, net.ns_in
    w_four_ex, w_four_in = net.w_four_ex, net.w_four_in
    fw_plan!, rv_plan!, s_temp = net.fw_plan!, net.rv_plan!, net.s_temp
    sw_i, sw_e, sw_to_in, sw_to_ex = net.sw_i, net.sw_e, net.sw_to_in, net.sw_to_ex
    hist_in, hist_etoi, hist_etoe = net.hist_in, net.hist_etoi, net.hist_etoe
    hist_size, hist_idx = net.hist_size, net.hist_idx

    # Ring-buffer slots holding the spikes emitted `d` steps ago. hist_idx[1] is the slot
    # that is (re)written at the END of this step, so until then it still holds the spikes
    # of hist_size steps ago. The slot of `d` steps ago is therefore hist_idx[1] - d
    # (cyclic, 1-based). The former `mod(hist_idx[1] - d, hist_size) + 1` was off by one:
    # it gave a delay of d-1 steps for d >= 2 and of hist_size steps for d == 1.
    cur = hist_idx[1]
    hist_idx[2] = mod1(cur - hist_in, hist_size)   # past_in
    hist_idx[3] = mod1(cur - hist_etoi, hist_size) # past_etoi
    hist_idx[4] = mod1(cur - hist_etoe, hist_size) # past_etoe

    # delayed, kernel-convolved input to each population
    sw_to_in .= @view sw_e[:, :, hist_idx[3]]   # excitatory to inhibitory
    sw_to_ex .= @view sw_i[:, :, hist_idx[2]]   # inhibitory to excitatory
    sw_to_ex .+= @view sw_e[:, :, hist_idx[4]]  # excitatory to excitatory

    # global sinusoidal drive at the current time t[1]*dt (same for both populations)
    drive = a_m * sin(ω_m * t[1] * dt)

    # inhibitory population; Ref(rng) makes the single RNG a scalar in the broadcast,
    # so one normal draw is taken per grid unit
    psi[1] .+= (.-psi[1] .+ a_in .+ drive) .* (dt / τin) .+
               randn.(Ref(rng)) .* ns_in .+ sw_to_in ./ τin

    # spatially heterogeneous drive along the second grid index (column j)
    if d_0 != 0.0
        for j in 1:nsq
            f_j = d_0 * cos(j * 2π / λ_B + Θ + ϕ_t)
            col_in = view(psi[1], :, j)
            col_in .+= f_j * (dt / τin)
            # NOTE(review): psi[2] receives this before its own Euler update below, while
            # psi[1] received it after its update; the original ordering is kept.
            col_ex = view(psi[2], :, j)
            col_ex .+= f_j * (dt / τex)
        end
    end

    # excitatory population
    psi[2] .+= (.-psi[2] .+ a_ex .+ drive) .* (dt / τex) .+
               randn.(Ref(rng)) .* ns_ex .+ sw_to_ex ./ τex

    for p in eachindex(psi)
        ψ = psi[p]
        # Detect spikes at the threshold crossing itself, before the reset. (Testing
        # `psi == 0` after the reset also flagged units that merely sat at exactly 0.)
        s[p] .= ψ .>= 1.0
        # reset spiking units to 0 and bound the potential from below at -2
        ψ .= ifelse.(ψ .>= 1.0, 0.0, ifelse.(ψ .< -2.0, -2.0, ψ))

        # Circular convolution of the spikes with this population's coupling kernel:
        # in-place forward FFT of s_temp, product with the transformed kernel, then an
        # in-place inverse FFT.
        s_temp .= s[p]
        fw_plan! * s_temp
        s_temp .*= p == 1 ? w_four_in : w_four_ex
        rv_plan! * s_temp
        register = p == 1 ? sw_i : sw_e
        register[:, :, hist_idx[1]] .= real.(s_temp)   # store in the current history slot
    end

    # advance time counter and history 'pointer'
    t[1] += 1
    hist_idx[1] = mod1(hist_idx[1] + 1, hist_size)
    return nothing
end
##
##
# reinitialize the system
"""
    reinit_IC_same_rng(spnet)

Reset `spnet` in place so that a new run starts from the same initial condition and the
same noise sequence as the run right after construction:

- `rng` is reseeded with `rngseed` (noise stream);
- `rngIC` is reseeded with `rngseedIC`, and both membrane potentials `psi` are redrawn from
  it in the same order as at construction;
- the spikes `s` and the spike histories `sw_i`/`sw_e` are zeroed, and the history
  pointers `hist_idx` are set back to 1.

NOTE(review): the time counter `t` is not reset, so a time-dependent drive (`a_m != 0`)
continues from the current time. Confirm that this is intended.

Returns `nothing`.
"""
function reinit_IC_same_rng(spnet)
    psi, s = spnet.psi, spnet.s
    Random.seed!(spnet.rng, spnet.rngseed) # restart the noise stream
    # psi was drawn from MersenneTwister(rngseedIC) at construction, so rngIC must be
    # reseeded with rngseedIC (previously: rngseed) to reproduce exactly the same IC.
    Random.seed!(spnet.rngIC, spnet.rngseedIC)
    for p in eachindex(psi)
        rand!(spnet.rngIC, psi[p]) # same numbers as rand(rngIC, nsq, nsq), written in place
    end
    for p in eachindex(s)
        fill!(s[p], 0.0)
    end
    # also reset history intervals and 'pointers'
    fill!(spnet.sw_i, 0.0)
    fill!(spnet.sw_e, 0.0)
    fill!(spnet.hist_idx, 1)
    return nothing
end
##
# Here is another piece of code that I found important:
# when a structure containing FFT plans is stored and later loaded again, the loaded plans
# somehow break Julia. I did not investigate why; instead, the function below loads the
# stored structure and re-creates the FFT plans.
# Loading a system from disk (JLD2 file) therefore requires special care with FFT plans.
"""
    load_system(fn, sn)

Load the network object stored under the name `sn` in the JLD2 file `fn`, and return it
with freshly created FFT plans.

The stored `fw_plan!`/`rv_plan!` fields are unusable after loading (using them crashed
Julia), so both plans are rebuilt on the loaded `s_temp` buffer with the same flags as at
construction. The plans are replaced through Setfield's `@set`, which returns a new
object that shares all the other fields with the loaded one.
"""
function load_system(fn, sn)
    spnet = load(fn, sn) # JLD2 load of the single object named `sn`
    # re-create the plans, because the deserialized ones are corrupted and crash Julia
    spnet = @set spnet.fw_plan! = plan_fft!(spnet.s_temp; flags=FFTW.PATIENT, timelimit=Inf)
    spnet = @set spnet.rv_plan! = plan_ifft!(spnet.s_temp; flags=FFTW.PATIENT, timelimit=Inf)
    return spnet
end
# For storing and loading I use JLD2.jl, which is a really great package.
# I also use it to store the system's spiking data on disk in JLD2 format;
# for instance, this makes it easy to store sparse matrices.
@tfiers
Copy link

tfiers commented Jun 21, 2022

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment