GPUComplianceModifications.jl
using Revise
using CUDA
using UnPack # provides the @unpack macro used in integrate!(p::IFNF, dt); Parameters.jl also exports it
"""
# * Changes to the cells struct as follows:
# * Make the typing used dynamic and parametric, ie types can be either CuArray or regular Array, depending the input
# arguments that are passed to the constructor.
# TODO, I have good things about KernelAbstractions.jl it possible that kernel abstractions will remove my reliance on different method dispatch to choose CPU/GPU backend.
# * I added in refractory periods.
# * I made examples of CUDA kernel calls for cell Vm update integrate, note it is not faster yet, but I am working on it.
# * Stored the post synaptic targets of each pre-synaptic cell populations cell struct for the reason described below
## TODO, as needed:
# * Since each cuda kernel has a cost, do as much work as possible per cuda kernel call.
# to do as much work as possible in a cuda kernel, use code macros as a way to update both: `$V_{m} $`
# and weights together in the same loop, to make the folding together of synaptic weight updates with Vm updates, store the post
# synaptic targets of each pre-synaptic cell populations cell struct.
Note the CUDA kernels execute and produce results, but they are not yet as fast, or faster than CPU. To make the CUDA kernels
fast and merited I would explore 2D threads (using `.x`, and `.y`). Folding Vm update togethor with synaptic update, to maximise
work done in each cuda kernel call.
Example of a LIF update as a CUDA kernel
"""
function lif_kernel!(N, v, ge, gi, fire, u, dt, g, tr)
    τm, τe, τi, Vt, Vr, El, gL = (100.0, 5.0, 10.0, 0.2, 0.0, -60.0, 10.0)
    R = 1.75
    tref = 10.0
    i0 = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    stride = blockDim().x * gridDim().x # grid-stride loop over all N cells
    @inbounds for i = i0:stride:N
        g[i] += ge[i] + gi[i] # accumulate the total synaptic conductance per cell
        v[i] = v[i] * exp(-dt / τm) + Vr + g[i] + u[i]
        v[i] += (u[i] + g[i]) * (R / τm)
        ge[i] += dt * -ge[i] / τe
        gi[i] += dt * -gi[i] / τi
        g[i] += dt * -g[i] / (τi + τe)
        if tr[i] > 0 # check if in refractory period
            v[i] = Vr # clamp the voltage to the reset value
            tr[i] = tr[i] - 1 # reduce running counter of refractory period
        elseif v[i] > Vt # threshold crossing
            fire[i] = true
            tr[i] = Int(round(tref * dt)) # set refractory time
        end
    end
    nothing
end
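# A minimal sketch of the 2D-thread idea mentioned in the notes above; it is not part of the
# original gist. The `.x` dimension walks post-synaptic cells and the `.y` dimension walks each
# cell's synapses, so the Vm decay and a per-synapse weight update share one kernel launch. The
# flat matrices `w` (weights) and `targets` (post-synaptic indices, 0 = unused slot) are
# hypothetical names introduced only for this illustration.
function fused_vm_weight_sketch!(N, M, v, w, targets, dt, τm)
    ix = threadIdx().x + (blockIdx().x - 1) * blockDim().x # post-synaptic cell index
    iy = threadIdx().y + (blockIdx().y - 1) * blockDim().y # synapse index within that cell
    if ix <= N && iy <= M
        if iy == 1
            v[ix] = v[ix] * exp(-dt / τm) # Vm decay done once per cell
        end
        if targets[ix, iy] > 0
            w[ix, iy] += dt * -w[ix, iy] / τm # placeholder per-synapse update on the same pass
        end
    end
    return nothing
end
# Possible launch, assuming 16×16 thread tiles:
# @cuda threads=(16, 16) blocks=(cld(N, 16), cld(M, 16)) fused_vm_weight_sketch!(N, M, v, w, targets, dt, 100.0f0)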
"""
Example of the multidispatch integrate function that is chosen to execute when `fire` is of CuArray type.
"""
function integrate!(N::Integer,v::CuArray,dt::Real,ge::CuVector,gi::CuVector,fire::CuArray{Bool},u::CuVector,g::CuVector)
kernel = @cuda launch=false lif_kernel!(N, v, ge, gi, fire,u,dt,g,tr)
config = launch_configuration(kernel.fun)
xthreads = min(32, N)
xblocks = min(config.blocks, cld(N, xthreads))
kernel(N, v, ge, gi, fire, u,dt;threads=(xthreads), blocks=(xblocks<<2))
end
abstract type AbstractIFNF end
"""
A population of cells
"""
mutable struct IFNF{C<:Integer,Q<:AbstractArray{<:Bool},L<:AbstractVecOrMat{<:Real}} <: AbstractIFNF
N::C
v::L
ge::L
gi::L
fire::Q
u::L
tr::L
records::Dict
post_synaptic_targets::Array{Array{UInt64}} # SVector
function IFNF(N,v,ge,gi,fire,u,tr,records,post_synaptic_targets)
new{typeof(N),typeof(fire),typeof(ge)}(N,v,ge,gi,fire,u,tr,records,post_synaptic_targets)
end
function IFNF(N,fire,u,sim_type::Array)
v = typeof(sim_type)(ones(N).-55.0)
g = typeof(sim_type)(zeros(N))
ge = typeof(sim_type)(zeros(N))
gi = typeof(sim_type)(zeros(N))
tr = zeros(typeof(N),N)
post_synaptic_targets = Array{Array{UInt64}}(undef,N)
for i in 1:N
post_synaptic_targets[i] = Array{UInt64}([])
end
#post_synaptic_targets = SVector{N, Array{UInt32}}(post_synaptic_targets)
pre_synaptic_weights = Vector{Float32}(zeros(N))
records::Dict = Dict()
IFNF(N,v,ge,gi,fire,u,tr,records,post_synaptic_targets)
end
function IFNF(N,fire,u,post_synaptic_targets::Vector{Array{UInt64}})
v = typeof(u)(ones(N).-55.)
g = typeof(u)(zeros(N))
ge = typeof(u)(zeros(N))
gi = typeof(u)(zeros(N))
tr = zeros(typeof(N),N)
records::Dict = Dict()
IFNF(N,v,ge,gi,fire,u,tr,records,post_synaptic_targets)
end
function IFNF(N,sim_type::CuArray,post_synaptic_targets)
fire::CuArray{Bool} = zeros(Bool,N)
u = typeof(sim_type)(zeros(N))
IFNF(N,fire,u,post_synaptic_targets)
end
function IFNF(N,sim_type::Array,post_synaptic_targets)
fire::Array{Bool} = zeros(Bool,N)
u = typeof(sim_type)(zeros(N))
IFNF(N,fire,u,post_synaptic_targets)
end
end
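# Hypothetical construction sketch (function name and sizes are illustrative, not part of the
# original gist): the array type of the second constructor argument selects the backend, so the
# same call builds either a CPU (`Array`) or GPU (`CuArray`) population.
function build_population_example(N = 10; gpu = false)
    post_synaptic_targets = Array{Array{UInt64}}(undef, N)
    for i in 1:N
        post_synaptic_targets[i] = UInt64[] # no connections yet
    end
    sim_type = gpu ? CUDA.zeros(Float32, N) : zeros(Float32, N)
    return IFNF(N, sim_type, post_synaptic_targets)
end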
"""
[Uniform parameter: Integrate-And-Fire Neuron](https://neuronaldynamics.epfl.ch/online/Ch1.S3.html)
"""
IFNF
function integrate!(p::IFNF, dt::Float32)
    @unpack N, v, ge, gi, fire, u, tr = p
    integrate!(N, v, dt, ge, gi, fire, u, tr)
end
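# Hypothetical driver sketch (not part of the original gist) showing how the population struct
# and the dispatch-based integrator fit together: build a population with the helper above,
# inject a constant current, and step it for `nsteps`, counting how many cells fired each step.
function run_example(; nsteps = 100, dt = 0.1f0, gpu = false)
    p = build_population_example(100; gpu = gpu)
    p.u .= 11.0f0 # constant external current drive, chosen only to make some cells spike
    spike_counts = zeros(Int, nsteps)
    for t in 1:nsteps
        p.fire .= false # clear last step's spike flags; the integrator only sets them
        integrate!(p, dt)
        spike_counts[t] = sum(p.fire)
    end
    return spike_counts
end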
function integrate!(N::Integer,v::Vector,dt::Real,ge::Vector,gi::Vector,fire::Vector{Bool},u::Vector{<:Real},tr::Vector{<:Number})
    τe, τi = 5.0, 10.0
    τ::Real = 8.0
    R::Real = 10.0
    θ::Real = -50.0 # firing threshold
    vSS::Real = -55.0 # steady-state / reset voltage
    tref = 10.0
    @inbounds for i = 1:N
        ge[i] += dt * -ge[i] / τe
        gi[i] += dt * -gi[i] / τi
        g = ge[i] + gi[i]
        v[i] = v[i] + (g + u[i]) * R / τ
        v[i] += (dt / τ) * (-v[i] + vSS)
        if tr[i] > 0 # check if in refractory period
            v[i] = vSS # clamp the voltage to the reset value
            tr[i] = tr[i] - 1 # reduce running counter of refractory period
        elseif v[i] > θ # threshold crossing
            fire[i] = true
            tr[i] = Int(round(tref * dt)) # set refractory time
        end
    end
end
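# A small hypothetical check (not in the original gist) of the refractory logic on the CPU path:
# drive a single cell hard enough to cross threshold on the first step and confirm that, on the
# next step, the refractory counter counts down and the voltage is clamped to the reset value.
function refractory_check(dt = 0.1)
    N = 1
    v = [-54.0]; ge = [0.0]; gi = [0.0]
    u = [100.0] # strong drive, chosen only to force an immediate spike
    fire = [false]
    tr = [0]
    integrate!(N, v, dt, ge, gi, fire, u, tr) # the cell crosses θ, fire is set, tr is set
    integrate!(N, v, dt, ge, gi, fire, u, tr) # the cell is refractory: v is clamped to vSS
    return fire[1], tr[1], v[1] # expected: (true, 0, -55.0)
end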
Base.show(io::IO, ::MIME"text/plain") =
print(io, """LIF:
voltage: $(neuron.v)
current: $(neuron.u)
τm: $(neuron.τm)
Vr : $(neuron.Vr)""")
Base.show(io::IO, neuron::IFNF) =
print(io, "LIF(voltage = $(neuron.v), current = $(neuron.u))")
"""
LIF(τm, vreset, vth, R = 1.0)
Create a LIF neuron with zero initial voltage and empty current queue.
"""