Skip to content

Instantly share code, notes, and snippets.

@mstimberg
Last active July 21, 2022 13:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mstimberg/ebdd51a15130ddd0e61d0559637269ba to your computer and use it in GitHub Desktop.
Save mstimberg/ebdd51a15130ddd0e61d0559637269ba to your computer and use it in GitHub Desktop.
Changing a stimulus in an ongoing simulation in Brian's C++ standalone mode
from brian2 import *
set_device('cpp_standalone')
# Parameters of the stimulation protocol
n_stimuli = 10  # number of distinct input patterns
stimulus_length = 100*ms  # duration of a single stimulus presentation
max_repetitions = 5  # limit on how often a single stimulus gets repeated
# Each "input" is one combination of rates across the 5 neurons: row k of the
# array holds the rates (uniformly drawn from [0, 100) Hz) of stimulus k.
# The input group looks stimulus k up at times in
# [k*stimulus_length, (k+1)*stimulus_length), so the TimedArray's dt must be
# stimulus_length itself rather than a separately hard-coded duration.
rates = TimedArray(np.random.rand(n_stimuli, 5)*100*Hz, dt=stimulus_length)
# Input layer: five neurons spiking stochastically with probability rate*dt per
# time step. The rate is the current stimulus' entry in `rates`, scaled by
# stimulus_strength. start_t, stimulus_idx, stimulus_strength and repetitions are
# bookkeeping variables that the stimulus control further down overwrites.
# NOTE(review): the control updates start_t in the 'before_end' slot, after the
# threshold of that step was evaluated, so on the last step of each trial
# t - start_t appears to equal stimulus_length and the next row of `rates` is
# used for one time step — confirm whether this matters.
input_equations = '''
rate = stimulus_strength * rates(t - start_t + stimulus_length*stimulus_idx, i) : second**-1
start_t : second # Start time of the current trial
stimulus_idx : integer # Index of the stimulus to show
stimulus_strength : 1 # Factor to scale up stimulus
repetitions : integer # Number of times the stim has been presented'''
input_spikes = NeuronGroup(5, model=input_equations, threshold='rand()<rate*dt')
# Every stimulus starts out unscaled.
input_spikes.stimulus_strength = 1
# Target layer: leaky integrators (membrane time constant tau) driven by an
# exponentially decaying synaptic current (time constant tau_syn); a neuron
# spikes when v exceeds 1 and is then reset to 0.
tau = 10*ms
tau_syn = 5*ms
target_equations = '''dv/dt = (-v + I)/tau : 1
dI/dt = -I/tau_syn : 1 # synaptic current'''
target_neurons = NeuronGroup(5, target_equations,
                             threshold='v > 1',
                             reset='v = 0')
# One-to-one wiring: each input spike adds 1 to the current of the target
# neuron with the same index.
connections = Synapses(input_spikes, target_neurons, on_pre='I += 1')
connections.connect(j='i')
# A dummy "synapse" and a dummy group, used to keep track of the total number of
# spikes emitted by the target group during the current stimulus presentation.
spike_counter = NeuronGroup(1, 'spike_counter : integer')
counter_synapse = Synapses(target_neurons, spike_counter, on_pre='spike_counter += 1')
counter_synapse.connect()
# Reset the counter at the end of every presentation (every stimulus_length).
# This runs in the 'end' slot, i.e. after the stimulus control (slot
# 'before_end') has read the count of the same time step. The reset is done by
# the counter group itself, so its single value is written once instead of
# once per synapse of counter_synapse.
spike_counter.run_regularly('spike_counter = 0', when='end', dt=stimulus_length)
# Another dummy "synapse" that uses the information in the spike counter to
# update the stimulus and/or its strength at the end of every presentation
# (every stimulus_length, in the 'before_end' slot, i.e. before the counter is
# reset in the 'end' slot). For each input neuron:
#   * the presentation that just ended is counted in `repetitions`;
#   * if the target group fired more than `min_spikes_to_switch` spikes, or the
#     stimulus has now been presented `max_repetitions` times, the next
#     stimulus is selected and its repetition count and strength are reset;
#   * otherwise the same stimulus is repeated with its strength multiplied by
#     `strength_increase`.
# The evaluation at t=0 happens before anything was presented, so it leaves
# everything except start_t unchanged. Once the index runs past the last
# stimulus, the strength is forced to 0, silencing the input.
# Brian's abstract code has no `if`, hence the case distinction via int(...) masks.
min_spikes_to_switch = 5  # switch when the targets fired MORE spikes than this
strength_increase = 1.5  # multiplicative boost applied on every repetition
stimulus_control = Synapses(spike_counter, input_spikes)
stimulus_control.connect()
stimulus_control.run_regularly('''
trial_ended = t > 0*ms
repetitions_post += int(trial_ended)
switch_stimulus = trial_ended and (spike_counter_pre > min_spikes_to_switch or repetitions_post >= max_repetitions)
repeat_stimulus = trial_ended and not switch_stimulus
repetitions_post = int(not switch_stimulus)*repetitions_post
stimulus_idx_post += int(switch_stimulus)
stimulus_strength_post = int(switch_stimulus)*1.0 + int(repeat_stimulus)*strength_increase*stimulus_strength_post + int(not trial_ended)*stimulus_strength_post
start_t_post = t
stimulus_strength_post = int(stimulus_idx_post < n_stimuli)*stimulus_strength_post
''', dt=stimulus_length, when='before_end')
# Record the stimulus bookkeeping of input neuron 0 (all input neurons receive
# identical updates from the stimulus control), every input and output spike,
# and the membrane potential of all target neurons.
in_mon = StateMonitor(input_spikes, ['stimulus_idx', 'stimulus_strength'], record=0)
in_spikes = SpikeMonitor(input_spikes)
out_spikes = SpikeMonitor(target_neurons)
state_mon = StateMonitor(target_neurons, 'v', record=True)
# Simulate max_repetitions presentations for each of the n_stimuli stimuli
# (derived from the parameters instead of hard-coding the number of stimuli).
run(n_stimuli*max_repetitions*stimulus_length)
# Summary figure, top to bottom: input spike raster; stimulus index (left axis,
# C0) and stimulus strength (right axis, C1) of input neuron 0; membrane
# potentials of all target neurons; output spike raster.
fig, axs = plt.subplots(4, 1, sharex=True)
axs[0].plot(in_spikes.t/ms, in_spikes.i, '.')
axs[0].set(ylabel='input spikes')
axs[1].plot(in_mon.t/ms, in_mon.stimulus_idx[0], label='stim index')
axs[1].set(ylabel='stimulus index')
axs[1].yaxis.label.set_color('C0')
# Strength shares the time axis but has its own scale, so use a twin y-axis.
twin_ax = axs[1].twinx()
twin_ax.plot(in_mon.t/ms, in_mon.stimulus_strength[0], label='stim strength', color='C1')
twin_ax.set(ylabel='stimulus strength')
twin_ax.yaxis.label.set_color('C1')
axs[2].plot(state_mon.t/ms, state_mon.v.T, color='gray')
axs[2].set(ylabel='V')
axs[3].plot(out_spikes.t/ms, out_spikes.i, '.')
axs[3].set(xlabel='time (ms)', ylabel='output spikes')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment