@wsphillips
Last active November 17, 2022 20:42
Flux LIF neuron
using Flux, Distributions

# Layer for LIF neurons:
struct LIFNeurons
    dt
    W_inputs      # size(in) array input weights
    W_rex_inputs  # length(in)xlength(in) array recurrent excitation weights
    tau_mem
    tau_syn
end

function LIFNeurons(n::Integer; dt=1, tau_mem=10, tau_syn=10)
    LIFNeurons(dt, rand(n), rand(n,n), tau_mem, tau_syn)
end

# 1 input => 1 output + 3 recurrent hidden states fed back
function (lif::LIFNeurons)((S,V,I), x)
    (; dt, tau_mem, tau_syn) = lif
    layer_input = lif.W_inputs .* x
    rex_input = reshape(sum(lif.W_rex_inputs .* S, dims=1), :)
    I_next = I .+ (dt/tau_syn) .* (-I .+ layer_input .+ rex_input)
    V_ = V .+ (dt/tau_mem) .* (-V .+ I_next)
    S_next = V_ .>= 1.0
    V_next = (1 .- S_next) .* V_
    return (S_next, V_next, I_next), S_next
end

N = 10 # number of neurons
p = Poisson(4)
spiking_lif_layer = Flux.Recur(LIFNeurons(N), (falses(N), zeros(N), zeros(N)))
spiking_lif_layer(rand(p, N))
@avik-pal

avik-pal commented Nov 7, 2022

Here's a Lux version:

using Lux, Random

struct LIFNeurons <: Lux.AbstractExplicitLayer
    dt
    tau_mem
    tau_syn
    n
end

function LIFNeurons(n::Integer; dt=1, tau_mem=10, tau_syn=10)
    LIFNeurons(dt, tau_mem, tau_syn, n)
end

function Lux.initialparameters(rng::AbstractRNG, l::LIFNeurons)
    return (; W_inputs = rand(rng, l.n), W_rex_inputs = rand(rng, l.n, l.n))
end

# 1 input => 1 output + 3 recurrent hidden states fed back
function (lif::LIFNeurons)((x, (S, V, I)), ps, st)
    (; dt, tau_mem, tau_syn) = lif
    layer_input = ps.W_inputs .*  x
    rex_input = reshape(sum(ps.W_rex_inputs .* S, dims=1), :)
    V_next = exp(-(dt/tau_mem)) .* V .+ I .- S
    I_next = exp(-(dt/tau_syn)) .* I .+ layer_input .+ rex_input
    S_next = V_next .>= 1.0
    return (S_next, (S_next,V_next,I_next)), st
end

function (lif::LIFNeurons)(x::AbstractVector, ps, st)
    return lif((x, (zeros(lif.n), zeros(lif.n), zeros(lif.n))), ps, st)
end

N = 10
spiking_lif_layer = StatefulRecurrentCell(LIFNeurons(N))
ps, st = Lux.setup(Random.default_rng(), spiking_lif_layer)

spiking_lif_layer(rand(0:1, N), ps, st)

@ChauhanT

ChauhanT commented Nov 8, 2022

Awesome!

If we assume dt is very small, we can approximate the exponential updates with something like:

V_next = V .+ (dt/tau_mem) .* (-V .+ I)
I_next = I .+ (dt/tau_syn) .* (-I .+ layer_input .+ rex_input)
S_next = V_next .>= 1.0

Now we reset :

V_next  = (1 .- S_next) .* V_next

We need the calculation of V_next twice to ensure a proper reset (or we could just rename the first V_next to V_ or something).
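
As a quick numerical check of the small-dt approximation above, the sketch below compares the exact decay factor exp(-dt/tau) from the Lux version with its first-order expansion 1 - dt/tau used in the Euler-style update. tau = 10 matches the default in the gist; the list of dt values is an arbitrary choice made for illustration only.

```julia
# Compare the exact decay factor with its first-order (Euler) approximation.
tau = 10.0
dts = [2.0, 1.0, 0.1, 0.01]   # illustrative step sizes, not taken from the gist

exact = exp.(-dts ./ tau)     # decay factor in the exponential update
euler = 1 .- dts ./ tau       # decay factor implied by V .+ (dt/tau) .* (-V)
abs_err = abs.(exact .- euler)

for (dt, e, a, err) in zip(dts, exact, euler, abs_err)
    println("dt = ", dt, "  exp(-dt/tau) = ", e, "  1 - dt/tau = ", a, "  |error| = ", err)
end
```

The gap between the two factors shrinks roughly like (dt/tau)^2 / 2, so the approximation is reasonable when dt is much smaller than the time constants.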

@ChauhanT

ChauhanT commented Nov 8, 2022

The above will probably not work :D The voltage update there does not use the updated current. This will work better:

# First let the feed-forward input propagate
I_next = I .+ (dt/tau_syn) .* (-I .+ layer_input .+ rex_input)

# Next we do the pure voltage updates without the discontinuities
V_ = V .+ (dt/tau_mem) .* (-V .+ I_next)

# Spike calculation
S_next = V_ .>= 1.0

# Voltage updates with reset
V_next  = (1 .- S_next) .* V_
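
For anyone who wants to try this update order outside Flux or Lux, here is a minimal plain-Julia sketch. The function names lif_step and run_lif, the v_th keyword, and the constant test inputs are assumptions made for this example; `input` stands for layer_input .+ rex_input already summed.

```julia
# One Euler step of a LIF population, in the order described above.
# `input` is the total synaptic drive (feed-forward plus recurrent), already summed.
function lif_step(V, I, input; dt=1.0, tau_mem=10.0, tau_syn=10.0, v_th=1.0)
    I_next = I .+ (dt / tau_syn) .* (-I .+ input)   # 1. synaptic current
    V_ = V .+ (dt / tau_mem) .* (-V .+ I_next)      # 2. voltage, no reset yet
    S_next = V_ .>= v_th                            # 3. spike detection
    V_next = (1 .- S_next) .* V_                    # 4. reset spiking neurons to 0
    return S_next, V_next, I_next
end

# Drive a small population with a constant input and count spikes per neuron.
function run_lif(input, n_steps)
    V = zeros(length(input))
    I = zeros(length(input))
    counts = zeros(Int, length(input))
    for _ in 1:n_steps
        S, V, I = lif_step(V, I, input)
        counts .+= S
    end
    return counts, V
end

counts, V_final = run_lif([0.5, 2.0, 4.0], 200)
println(counts)
```

With these defaults the first neuron never reaches threshold because its steady-state voltage equals its input of 0.5, while the other two fire repeatedly.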

@wsphillips
Author

@ChauhanT updated

@darsnack

darsnack commented Nov 8, 2022

Here is a slightly modified variant:

using Flux
using ChainRulesCore
using Functors: @functor

function spike_threshold(v, vth)
    s = v .>= vth
    v = (1 .- s) .* v

    return s, v
end
spike_threshold(vth) = Base.Fix2(spike_threshold, vth)

struct LIFCell{T, F, R}
    forward::T
    activation::F
    dt::R
    tau_s::R
    tau_m::R
end

LIFCell(forward, activation = spike_threshold(1f0); dt = 1f0, tau_s = 10f0, tau_m = 10f0) =
    LIFCell(forward, activation, dt, tau_s, tau_m)

function apply_spikes!(m::Flux.Recur, s)
    m.state = s
    return m
end
apply_spikes!(m, s) = m
ChainRulesCore.@nondifferentiable apply_spikes!

function (lif::LIFCell)((S, V, I), x)
    Iin = lif.forward(x)
    I = I .+ (lif.dt / lif.tau_s) .* (Iin .- I)
    V = V .+ (lif.dt / lif.tau_m) .* (I - V)
    S, V = lif.activation(V)
    apply_spikes!(lif.forward, S)

    return (S, V, I), S
end

@functor LIFCell

LIF(args...; init_state, kwargs...) =
    Flux.Recur(LIFCell(args...; kwargs...), init_state)

# a fully connected example
input = rand(Bool, 10, 50)
Nin = size(input, 1)
Nout = 5
S, V, I = falses(Nout), zeros32(Nout), zeros32(Nout)
spiking_ff = LIF(Dense(Nin => Nout); init_state = (S, V, I))
spiking_ff(input)

# with recurrence
spiking_rnn = LIF(RNN(Nin => Nout, identity); init_state = (S, V, I))
spiking_rnn(input)

# a convolution example
input = rand(Bool, 32, 32, 3, 50)
conv = Conv((3, 3), 3 => 16)
Nin = size(input)
Nout = Flux.outputsize(conv, Nin)[1:3]
S, V, I = falses(Nout...), zeros32(Nout...), zeros32(Nout...)
spiking_conv = LIF(conv; init_state = (S, V, I))

What I changed:

  • LIFCell can wrap an arbitrary Flux layer. Other option would be to let x be the input "current" from a preceding layer like Conv. I think this is what the initial implementation was going for? Doing it this way lets RNN handle the recurrence, avoiding that computation if you don't want recurrence without code duplication.
    • If you prefer keeping the recurrence in LIFCell, then doing Ir = reshape(lif.Wr * MLUtils.flatten(S), size(S)) will hit BLAS.
  • The spiking behavior is abstracted out to an activation function which will allow rate-encoded networks to use the same layer.
  • Norse seems to handle this with separate LIFCell, LIFRecurrentCell, and LICell.

@darsnack

darsnack commented Nov 8, 2022

Btw @avik-pal, in the Lux case I am confused why (S, V, I) is not part of st?

@wsphillips
Author

LIFCell can wrap an arbitrary Flux layer. Other option would be to let x be the input "current" from a preceding layer like Conv. I think this is what the initial implementation was going for? Doing it this way lets RNN handle the recurrence, avoiding that computation if you don't want recurrence without code duplication.

Yes, this is better. Due to inexperience I don't (yet) fully grok the implicit connections that happen with the default layer types, but I think I get it now. So an RNN layer has an additional/independent weight matrix for the hidden state? I assumed it was element-wise unweighted feedback initially. And you're forcing the RNN layer into summing the previous layer input (x) with the LIF layer output (S) by manually mutating its hidden state at the end of each LIFCell evaluation (via apply_spikes!)?

I think the only thing left is to define a simple @scalar_rule for spike_threshold...

Using this gist as a template, we can probably write up other neuron model types pretty easily (e.g. Izhikevich, maybe AdEx and other IF variants).

I guess for STDP we would make DenseSTDP and RecurrentSTDP layers (consumed by LIFCell) that hold hidden states for the pre/post synaptic "traces" and then tag their parameters as frozen.

@ChauhanT

ChauhanT commented Nov 9, 2022

So this is a pretty important design decision I think - do we want the synapses to stay with the post-synaptic layers or not. It really changes how you implement some plasticity rules. For gradient-based learning the usual flux/lux method works great because the synapses fit in very nicely as properties of the post-synaptic layer. I am open to seeing how implementing the usual mono-synaptic STDP (the first rule I would like to implement) works with this sort of setup. This gives us the added advantage of not having to maintain a separate class/structure for the weights.

However, I think if we go this route, we need to design a basic set of functions (in terms of the API) which the user has to implement for a 'custom' layer to work ootb. In C++ you'd use an abstract class to do this sort of stuff - not sure what the Julia equivalent is, so looking forward to developing some Julia chops along the way.
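
On the abstract-class question: the closest Julia idiom is an abstract type plus a small set of generic functions that each concrete subtype must implement. The sketch below is one hypothetical shape for such an interface; AbstractSpikingCell, init_state, forward, SimpleLIF, and simulate are all names invented for this example and are not part of Flux or Lux.

```julia
# Hypothetical interface: every spiking cell subtypes AbstractSpikingCell and
# implements `init_state` and `forward`. None of these names come from Flux or Lux.
abstract type AbstractSpikingCell end

# Fallbacks give a readable error when a custom cell forgets a method.
init_state(cell::AbstractSpikingCell, n::Integer) =
    error("init_state not implemented for $(typeof(cell))")
forward(cell::AbstractSpikingCell, state, x) =
    error("forward not implemented for $(typeof(cell))")

# A concrete cell: current-free LIF with only a membrane voltage.
struct SimpleLIF <: AbstractSpikingCell
    dt::Float64
    tau_mem::Float64
    v_th::Float64
end

init_state(::SimpleLIF, n::Integer) = (V = zeros(n),)

function forward(cell::SimpleLIF, state, x)
    V_ = state.V .+ (cell.dt / cell.tau_mem) .* (-state.V .+ x)
    S = V_ .>= cell.v_th
    return (V = (1 .- S) .* V_,), S   # new state, output spikes
end

# Generic code written only against the interface.
function simulate(cell::AbstractSpikingCell, inputs)
    state = init_state(cell, length(first(inputs)))
    spikes = BitVector[]
    for x in inputs
        state, S = forward(cell, state, x)
        push!(spikes, S)
    end
    return spikes
end

spikes = simulate(SimpleLIF(1.0, 10.0, 1.0), [fill(3.0, 4) for _ in 1:20])
println(findfirst(any, spikes))
```

Multiple dispatch does the job that virtual methods do in C++: simulate works for any subtype that provides the two methods, and a missing method fails with the error message from the fallback.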

@darsnack

darsnack commented Nov 9, 2022

Yes, this is better. Due to inexperience I don't (yet) fully grok the implicit connections that happen with the default layer types, but I think I get it now. So an RNN layer has an additional/independent weight matrix for the hidden state? I assumed it was element-wise unweighted feedback initially. And you're forcing the RNN layer into summing the previous layer input (x) with the LIF layer output (S) by manually mutating its hidden state at the end of each LIFCell evaluation (via apply_spikes!)?

Right, RNNCell is a function like (h, x) -> activation.(Wi * x .+ Wr * h .+ b). The Wi * x + Wr * h == Iin + Ir in our LIF model. I am manually hacking h to be the spikes via apply_spikes!, since RNN is sugar for Recur(RNNCell, h). Another implementation would be to have LIFCell wrap RNNCell directly. Then we can pass S in for h instead of hacking it. But then we want lif.forward(S, x), and the non-recurrent layers don't have state, so we'd need some kind of indirection for Dense or Conv. Recur is basically Flux's way of making non-recurrent and recurrent layers share the same input interface. It will be a design decision whether we want that or not.
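
To make the "sugar" concrete, here is a stripped-down stand-in for the pattern in plain Julia. MiniRecur and TinyRNNCell are illustrative names, not Flux types, and the weights are arbitrary; the sketch only shows how a wrapper that stores state gives a cell of the form (h, x) -> (h_next, y) a one-argument call, and how overwriting the stored state by hand mimics what apply_spikes! does.

```julia
# Illustrative stand-ins, not Flux types.
struct TinyRNNCell{M,V}
    Wi::M
    Wr::M
    b::V
end

# (h, x) -> (new state, output), with activation tanh
function (c::TinyRNNCell)(h, x)
    h_next = tanh.(c.Wi * x .+ c.Wr * h .+ c.b)
    return h_next, h_next
end

# Wrapper that owns the state so callers only pass x.
mutable struct MiniRecur{C,S}
    cell::C
    state::S
end

function (m::MiniRecur)(x)
    m.state, y = m.cell(m.state, x)
    return y
end

cell = TinyRNNCell([1.0 0.0; 0.0 1.0], [0.5 0.0; 0.0 0.5], zeros(2))
m = MiniRecur(cell, zeros(2))

y1 = m([1.0, 0.0])
y2 = m([1.0, 0.0])        # differs from y1 because the state was fed back

# Overwrite the hidden state by hand, in the spirit of apply_spikes!:
m.state = [1.0, 0.0]      # pretend these are spikes from the LIF layer
y3 = m([0.0, 0.0])        # output now depends only on Wr * spikes
println((y1, y2, y3))
```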

I think the only thing left is to define a simple @scalar_rule for spike_threshold...

Yes this should be all that's needed for surrogate gradients.
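
A rough sketch of what such a rule could look like with ChainRulesCore follows. It is not the rule from this gist: the scalar heaviside function, the fast-sigmoid surrogate, and the slope constant 10 are all assumptions chosen for illustration, and how the rule composes with broadcasting depends on the AD backend in use.

```julia
using ChainRulesCore

# Forward pass: a hard threshold at zero (apply it to v - vth).
heaviside(v::Real) = v >= 0 ? one(v) : zero(v)

# Backward pass: replace the zero-almost-everywhere derivative with a
# fast-sigmoid surrogate 1 / (1 + beta * |v|)^2. beta = 10 is an arbitrary choice.
@scalar_rule heaviside(v) inv((1 + 10 * abs(v))^2)

# Query the rule directly, without a full AD package.
y, pullback = ChainRulesCore.rrule(heaviside, 0.1)
_, dv = pullback(1.0)
println((y, unthunk(dv)))
```

The forward value stays a hard 0/1 spike, while the pullback returns a smooth, nonzero sensitivity that peaks at the threshold.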

I guess for STDP we would make DenseSTDP and RecurrentSTDP layers (consumed by LIFCell) that hold hidden states for the pre/post synaptic "traces" and then tag their parameters as frozen.

So this is a pretty important design decision I think - do we want the synapses to stay with the post-synaptic layers or not. It really changes how you implement some plasticity rules. For gradient-based learning the usual flux/lux method works great because the synapses fit in very nicely as properties of the post-synaptic layer. I am open to seeing how implementing the usual mono-synaptic STDP (the first rule I would like to implement) works with this sort of setup. This gives us the added advantage of not having to maintain a separate class/structure for the weights.

I think if we're going to build on an ML framework like Flux or Lux, it is best to think about how SNNs look in the usual interface. This would be the weights living with the post-synaptic neurons, and the updates being handled by a "gradient calculator" + optimizer. Coming from frameworks like Brian, we think about the network as an object that we set off at t = 0 and it self-updates according to some rules. I think if you are taking that dynamical systems abstraction for the design, you're better off with something like DiffEq.jl or an event-based simulator.

From an ML perspective, you would have: (1) let model compute output from input, (2) compute a weight update (typically via gradients), and (3) give the weight update to an optimizer to actually change the model. In pseudocode, standard ANN training, SNN training via surrogate gradients, and SNN training via STDP look like:

# ANN training
for (x, y) in data
    grad = gradient(m -> loss(m(x), y), model)
    update!(optimizer, model, grad)
end

# SNN training via surrogate gradients
for (x, y) in data
    grad = gradient(model) do m
        yp = [m(xt) for xt in make_spikes(x)]
        loss(yp, make_spikes(y))
    end
    update!(optimizer, model, grad)
end

# SNN training via STDP
for (x, y) in data
    for xt in make_spikes(x)
        z = Flux.activations(model, xt) # get intermediate spikes
        pseudo_grad = stdp(model, (xt, z[1:(end - 1)]...), z[2:end])
        update!(optimizer, model, pseudo_grad)
    end
end

Steps (1-3) together make up what ML calls the "training loop." There exist libraries (e.g. FluxTraining.jl) for defining training loops without manually writing them out. We could define the rules for surrogate gradient training or STDP training as extensions to FluxTraining.jl. Doing things this way will result in better plug-and-play with the existing ML ecosystem. For example, you could log your spike trains using built-in logging callbacks instead of implementing a custom "monitor" or "recorder."
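
To illustrate the "STDP as a pseudo-gradient" idea at the array level, here is a minimal plain-Julia sketch of a pair-based rule with exponential traces. The function names, the trace dynamics, and the constants a_plus, a_minus, and tau are assumptions made for this example, not the stdp function from the pseudocode. The sign is flipped so that a plain descent step W .-= eta .* grad applies the intended weight change.

```julia
# Pair-based STDP at the array level. All names and constants are illustrative.
# W has size (n_post, n_pre); traces are low-pass filtered spike trains.
function update_trace(trace, spikes; dt=1.0, tau=20.0)
    return trace .+ (dt / tau) .* (-trace) .+ spikes
end

function stdp_pseudo_grad(pre_trace, post_trace, pre_spikes, post_spikes;
                          a_plus=0.01, a_minus=0.012)
    # Potentiate when a post spike follows recent pre activity,
    # depress when a pre spike follows recent post activity.
    dW = a_plus .* (post_spikes .* pre_trace') .-
         a_minus .* (post_trace .* pre_spikes')
    return -dW   # negated so that W .-= eta .* grad applies +eta .* dW
end

function run_stdp(W, pre_train, post_train; eta=1.0)
    W = copy(W)
    pre_trace = zeros(size(W, 2))
    post_trace = zeros(size(W, 1))
    for (pre_s, post_s) in zip(pre_train, post_train)
        grad = stdp_pseudo_grad(pre_trace, post_trace, pre_s, post_s)
        W .-= eta .* grad                       # plain descent step
        pre_trace = update_trace(pre_trace, pre_s)
        post_trace = update_trace(post_trace, post_s)
    end
    return W
end

# Pre neuron 1 fires at t = 1, post neuron 2 fires at t = 2 (pre before post).
pre_train  = [Bool[1, 0, 0], Bool[0, 0, 0]]
post_train = [Bool[0, 0], Bool[0, 1]]
W_new = run_stdp(fill(0.5, 2, 3), pre_train, post_train)
println(W_new)
```

Only the synapse from pre neuron 1 to post neuron 2 is potentiated here, since that is the only pair with a pre-before-post spike order.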

@ChauhanT

ChauhanT commented Nov 9, 2022

Love it @darsnack ! It's a fantastic approach to treat the stdp weight-changes as gradients, and leverage the existing Flux.jl framework. Two questions:

  1. Is there a strict API for grad ? Or is it simply an n-d array with all the gradients ?
  2. Can the optimizer implement any mapping f : w_{t-1} x dw_t --> w_t ? Otherwise we might need to write one which does. Hopefully the optimisers don't have a lot of crazy API.

I've got some ideas for implementing STDP which can make things a bit easier/faster - let us discuss that part later, once we have set up a working, efficient, spiking layer.

@darsnack

darsnack commented Nov 9, 2022

  1. For grad to be passed to the optimizer, it has to be in a specific form. I would suggest the form Optimisers.jl expects which is a tree of parameters with the same shape as the model. For example, Chain(Dense(10 => 5), Dense(5 => 2)) has a tree that looks like
    (layers = (
      (weight = ..., bias = ..., sigma = ...),
      (weight = ..., bias = ..., sigma = ...)
    ),)
    where ... is the parameter array or the gradient array. AD packages like Zygote will automatically generate gradients as a tree of this form. It is not hard to build such trees. Usually you'll define something like stdp at the array level then map it over the model using a combination of multiple dispatch and Functors.jl. It sounds confusing at first, but it will feel natural after getting used to Julia.
  2. Yes, to implement a custom optimizer that adheres to Optimisers.jl's interface, you need to implement Optimisers.apply!(opt::MyOpt, state, x, dx) (along with Optimisers.init to create the initial state), where state is the internal optimizer state (can be empty), x is a plain array, and dx is also a plain array. Optimisers.jl will handle applying this simple rule over a complex structured model and gradient for you. Though I think in our context, we can just use Descent which applies the rule x -= eta * dx where eta is the learning rate.

You can of course choose not to use Optimisers.jl at all as well, and manually loop over the model and update each array. Then there is no contract for what grad needs to look like. But the reason I wouldn't do this for a package is that using ML optimizers can be useful for accelerating training. This just makes long running simulations faster. For example, I've used SGD + momentum in the past for biological networks even though it isn't a biologically plausible update (depending on the paper, this approximation is excusable).
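
As a sketch of a custom mapping (w, dw) -> w under the Optimisers.jl interface, the rule below takes a descent step and then keeps the weights inside fixed bounds. BoundedDescent and its fields eta, lo, and hi are hypothetical names invented for this example; the toy model and pseudo-gradient are plain NamedTuples standing in for a real model tree.

```julia
using Optimisers

# Hypothetical rule: descent step followed by clamping to [lo, hi].
struct BoundedDescent{T} <: Optimisers.AbstractRule
    eta::T
    lo::T
    hi::T
end

# This rule needs no per-array state.
Optimisers.init(::BoundedDescent, x::AbstractArray) = nothing

# apply! returns (new_state, dx′); Optimisers.jl then computes x .- dx′.
function Optimisers.apply!(o::BoundedDescent, state, x, dx)
    x_new = clamp.(x .- o.eta .* dx, o.lo, o.hi)
    return state, x .- x_new
end

# Toy "model" and pseudo-gradient with the same tree shape.
model = (weight = [0.05, 0.5, 0.95],)
pseudo_grad = (weight = [1.0, -1.0, -1.0],)

opt_state = Optimisers.setup(BoundedDescent(0.1, 0.0, 1.0), model)
opt_state, model = Optimisers.update(opt_state, model, pseudo_grad)
println(model.weight)
```

Because apply! only returns the amount to subtract, any update that depends on both the current weights and the proposed change can be expressed by computing the target weights first and returning the difference.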

@ChauhanT

Thank you for the detailed outline of this @darsnack! I want to go back to the idea of wrapping RNNCell. The Recur block wraps both Linear and Conv layers? Why do you think it would be different if we were to use spiking layers instead of the normal ones?

I'm also thinking a spiking layer in a general form will need to be defined like:

  1. parameters P : e.g., for a non-adaptive LIF this would be: P = {dt , tau_m , tau_s , threshold}
  2. state Z : e.g., for a LIF this would be: Z = {I,V,S} where S is the spike indicator variable
  3. output O: e.g., for a LIF (or most spiking networks), this will simply be some jump-y variable like S, but it could be any function of Z

Now, given the output of the previous layer (X = O for previous layer) the forward pass is expressed as:

Z_t,O_t = forward ( Z_{t-1}, X_{t}, P )

So what makes this different from the already implemented recurrent layers in Flux.jl ? The fact that forward(.) will have discontinuities.

For starting out, we assume that all discontinuities in our layers occur when we extract O from Z. For LIFs, and most layers we will implement (or realistically need within an ML/bio-AI framework), the assumption is much stronger, because the discontinuity is a simple thresholding operation - allowing us to use simple forms of rules like super-spike or surrogate gradients.

This is not implemented in Flux.jl. We need to create the proper hooks/infrastructure for doing forward passes (discontinuities are passed on), and backward propagating gradients (discontinuities need to be dealt with).

This is a network that will learn using the usual backprop and rules already implemented in Lux/Flux. Now comes the question of introducing STDP and other non-gradient based rules. I'm still thinking through this part - especially with the Optimisers.jl trick @darsnack has proposed. But in the meanwhile, I would like to start work on getting the rest of the library up! I'll also need some help/guidance regarding how a repo is managed :D I have only made private repos to keep track of my code. I have no CI/CD experience. Happy to learn as we move forward!
