Last active
January 18, 2022 16:13
-
-
Save mstimberg/30c64429503400de40e7 to your computer and use it in GitHub Desktop.
Example of implementing a buffer for a delayed variable in Brian2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Tested with Brian 2.1.3.1
from brian2 import *

G = NeuronGroup(2, '''drate/dt = amplitude*sin(2*pi*100*Hz*t)/(10*ms) : Hz
rate_delayed : Hz
amplitude : Hz
buffer_pointer : integer (shared)
delay_steps : integer''', method='euler')
G.amplitude = [50, 100] * Hz

# Delay (in time steps) for each neuron. Kept under a different Python name than
# the model variable ``delay_steps`` so the two are never confused.
delays_per_neuron = [6, 24]
# A negative delay would make the modulo arithmetic in the buffer read a slot
# that has not been written for the intended time step.
if min(delays_per_neuron) < 0:
    raise ValueError('Delays (in time steps) must be non-negative')
G.delay_steps = delays_per_neuron

# The ring buffer needs one row for every delay from 0 up to the maximum delay,
# i.e. 1 + maximum delay rows. Deriving the size from the delays (instead of
# hard-coding it) ensures that no delay can wrap around the buffer and silently
# read a too-recent value.
buffer_size = int(max(delays_per_neuron)) + 1

# This is a low-level way of adding a state variable that normally should not be
# used, but it is currently the only way to add a two-dimensional state variable.
# Rows are buffer slots, columns are neurons.
G.variables.add_array('rate_buffer', dimensions=Hz.dim, size=(buffer_size, len(G)))

# At every time step, we advance the buffer pointer, store the current rate in
# the buffer and write a delayed version of it (potentially using different
# delays for different neurons) into the rate_delayed state variable.
update_code = '''buffer_pointer = (buffer_pointer + 1) % buffer_size
rate_delayed = update_rate_buffer(rate, rate_buffer, buffer_pointer, delay_steps, buffer_size)'''
# update_rate_buffer (defined below) only has a Python implementation, so this
# operation is forced to use the numpy code object even if the state update
# code runs in a compiled target such as weave.
buffer_updater = G.run_regularly(update_code, codeobj_class=NumpyCodeObject)
@check_units(rate=Hz, rate_buffer=Hz, buffer_pointer=1, delay_steps=1,
             buffer_size=1, result=Hz)
def update_rate_buffer(rate, rate_buffer, buffer_pointer, delay_steps, buffer_size):
    """Store the current rates in the ring buffer and return the delayed rates.

    ``rate_buffer`` is indexed as ``[slot, neuron]``. The current rates are
    written into row ``buffer_pointer``. For each neuron ``i``, the value
    stored ``delay_steps[i]`` rows before that row is returned, wrapping around
    the end of the buffer with a modulo.

    Because the write happens before the read, a delay of 0 yields the current
    rate. Only a Python (numpy) implementation exists: C++ code generation
    currently cannot deal with 2d arrays.
    """
    # Overwrite this slot with the rates of the current time step.
    rate_buffer[buffer_pointer, :] = rate
    # Row holding each neuron's value from delay_steps steps ago.
    source_rows = (buffer_pointer - delay_steps) % buffer_size
    neuron_columns = arange(len(source_rows))
    # Fancy indexing: one (row, column) pair per neuron.
    return rate_buffer[source_rows, neuron_columns]
mon = StateMonitor(G, ['rate', 'rate_delayed'], record=True)
run(50*ms, report='text')

# Plot each neuron's rate (solid) together with its delayed version (dotted),
# which could be fed as input into other cells. Each delayed trace reuses the
# colour of its original trace. Plotting all traces in two calls would advance
# matplotlib's colour cycle, giving the delayed traces different colours.
for neuron_idx, (rate_trace, delayed_trace) in enumerate(zip(mon.rate, mon.rate_delayed)):
    rate_line, = plt.plot(mon.t/ms, rate_trace/Hz,
                          label='neuron %d' % neuron_idx)
    plt.plot(mon.t/ms, delayed_trace/Hz, ':', color=rate_line.get_color(),
             label='neuron %d (delayed)' % neuron_idx)
plt.xlabel('Time (ms)')
plt.ylabel('Rate (Hz)')
plt.legend()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment