
@mstimberg
Last active January 18, 2022 16:13
Example of implementing a buffer for a delayed variable in Brian2
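The script below is built around a ring buffer. On every time step, the current rates are written into the next row, and each neuron reads back the row that was written `delay_steps` steps earlier. The sketch that follows checks the same index arithmetic in isolation. It assumes only NumPy, not Brian2. `ring_buffer_step` is a hypothetical helper that mirrors `update_rate_buffer`. The step counter is used as the input signal so the delay can be read off directly.

```python
import numpy as np

def ring_buffer_step(values, buffer, pointer, delay_steps):
    """Hypothetical helper: write `values` into row `pointer` of `buffer` and
    return, per column, the value written `delay_steps` steps earlier."""
    buffer_size, n = buffer.shape
    buffer[pointer, :] = values
    # Wrap around at the start of the buffer with a modulo, as in the script
    rows = (pointer - delay_steps) % buffer_size
    # Fancy indexing: one (row, column) pair per neuron
    return buffer[rows, np.arange(n)]

buffer_size = 25                 # 1 + maximum delay
delay_steps = np.array([6, 24])  # per-neuron delays in time steps
buffer = np.zeros((buffer_size, 2))
pointer = 0
outputs = []
for step in range(1, 41):
    # Advance the pointer first, like buffer_pointer in the Brian2 code
    pointer = (pointer + 1) % buffer_size
    signal = np.array([step, step], dtype=float)
    outputs.append(ring_buffer_step(signal, buffer, pointer, delay_steps))
outputs = np.array(outputs)
```

With delays of 6 and 24 steps, the two columns of `outputs` lag the input by exactly 6 and 24 steps. They contain zeros until the buffer has filled. A delay equal to `buffer_size` would wrap around to the row that was just written, giving an effective delay of zero. This is why the buffer needs one more row than the maximum delay.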
# Tested with Brian 2.1.3.1
from brian2 import *
import matplotlib.pyplot as plt  # explicit import for the plotting at the end
G = NeuronGroup(2, '''drate/dt = amplitude*sin(2*pi*100*Hz*t)/(10*ms) : Hz
rate_delayed : Hz
amplitude : Hz
buffer_pointer : integer (shared)
delay_steps : integer''', method='euler')
G.amplitude = [50, 100] * Hz
G.delay_steps = [6, 24]  # per-neuron delay in time steps (0.6 ms and 2.4 ms with the default dt of 0.1 ms)
buffer_size = 25  # must be at least 1 + the maximum delay (in time steps)
# This is a low-level way of adding a state variable that should normally not be used, but it is
# currently the only way to add a two-dimensional state variable.
G.variables.add_array('rate_buffer', dimensions=Hz.dim, size=(buffer_size, len(G)))
# At every time step, we update the buffer and store a delayed version of the rate (potentially using
# different delays for different neurons) in the rate_delayed state variable.
update_code = '''buffer_pointer = (buffer_pointer + 1) % buffer_size
rate_delayed = update_rate_buffer(rate, rate_buffer, buffer_pointer, delay_steps, buffer_size)'''
# Note that the function below is only implemented in Python, so we have to make sure that the update code runs
# with the NumPy target even if the state updates run with a compiled target (e.g. weave or Cython).
buffer_updater = G.run_regularly(update_code, codeobj_class=NumpyCodeObject)
# This function does the whole update/retrieve work for the buffer. We write the current rates into the row given by
# the current value of buffer_pointer (which is incremented every time step) and retrieve, for each neuron, the rates
# stored delay_steps rows before it (using a modulo operator to wrap around at the end of the array). Unfortunately,
# C++ code generation currently cannot deal with 2D arrays, so this can only be done in Python.
@check_units(rate=Hz, rate_buffer=Hz, buffer_pointer=1, delay_steps=1, buffer_size=1, result=Hz)
def update_rate_buffer(rate, rate_buffer, buffer_pointer, delay_steps, buffer_size):
    # Write current rate into the buffer
    rate_buffer[buffer_pointer, :] = rate
    # Get delayed rates
    rows = (buffer_pointer - delay_steps) % buffer_size
    return rate_buffer[rows, arange(len(rows))]
mon = StateMonitor(G, ['rate', 'rate_delayed'], record=True)
run(50*ms, report='text')
# Plot the two rates and their delayed versions (which could be fed as input
# into other cells)
plt.plot(mon.t/ms, mon.rate.T/Hz)
plt.plot(mon.t/ms, mon.rate_delayed.T/Hz, ':')
plt.show()
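The comments above note that C++ code generation cannot handle 2D arrays. One possible workaround is to store the buffer as a flat 1D array of `buffer_size * N` entries and to compute the row-major index `row * N + neuron` by hand. This is an assumption and has not been tested with Brian2 here. The NumPy sketch below only checks that flat indexing reproduces the 2D version. `flat_step` and `grid_step` are hypothetical helpers. The sketch does not show how to register such a flat array and per-element code with Brian2's code generation.

```python
import numpy as np

def grid_step(values, buffer, pointer, delay_steps):
    # 2D version, same logic as update_rate_buffer
    buffer[pointer, :] = values
    rows = (pointer - delay_steps) % buffer.shape[0]
    return buffer[rows, np.arange(buffer.shape[1])]

def flat_step(values, flat_buffer, pointer, delay_steps, buffer_size):
    # 1D version: element (row, neuron) lives at row * n + neuron (row-major)
    n = len(values)
    neurons = np.arange(n)
    flat_buffer[pointer * n + neurons] = values
    rows = (pointer - delay_steps) % buffer_size
    return flat_buffer[rows * n + neurons]

rng = np.random.default_rng(0)
buffer_size, n = 25, 3
delay_steps = np.array([0, 6, 24])
grid = np.zeros((buffer_size, n))
flat = np.zeros(buffer_size * n)
pointer = 0
max_diff = 0.0
for _ in range(100):
    pointer = (pointer + 1) % buffer_size
    values = rng.normal(size=n)
    a = grid_step(values, grid, pointer, delay_steps)
    b = flat_step(values, flat, pointer, delay_steps, buffer_size)
    # Track the largest disagreement between the two layouts
    max_diff = max(max_diff, float(np.abs(a - b).max()))
```

Both layouts hold the same numbers in the same order, so `flat` equals `grid.ravel()` after every step. A per-element loop that only indexes 1D arrays might then be expressible in a compiled target, but that part remains untested.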