Skip to content

Instantly share code, notes, and snippets.

@dsalaj
Last active December 9, 2019 08:26
Show Gist options
  • Save dsalaj/dc18edc82053df35087f2ab3026f61e7 to your computer and use it in GitHub Desktop.
Save dsalaj/dc18edc82053df35087f2ab3026f61e7 to your computer and use it in GitHub Desktop.
Crossing Threshold Encoding of pixel values to spikes
def find_onset_offset(y, threshold):
    """
    Encode the threshold crossings of a 1-D signal as binary spike trains.

    For ``threshold != 1`` two rows are returned:

    * row 0 (onset) holds a 1 at every index ``i`` where
      ``y[i] < threshold <= y[i + 1]`` (the signal rises through the threshold),
    * row 1 (offset) holds a 1 at every index ``i`` where
      ``y[i] >= threshold > y[i + 1]`` (the signal falls through the threshold).

    The spike is placed on the sample *before* the crossing, so the last
    sample never carries an onset or offset spike.

    For ``threshold == 1`` a single row is returned instead; it holds a 1 at
    every sample that is exactly equal to the threshold ("touch" spikes).
    NOTE(review): presumably 1 is the maximum value of the signal, so it can
    only be touched and never crossed -- confirm against the callers.

    :param y: 1-D array-like with the samples of the signal.
    :param threshold: scalar level the signal is compared against.
    :return: array with the dtype of `y` and shape ``(2, len(y))``, or
        ``(1, len(y))`` when ``threshold == 1``.
    :raises ValueError: if `y` is not one-dimensional.
    """
    # Accept any array-like (e.g. a plain list), not only numpy arrays.
    y = np.asarray(y)
    if y.ndim != 1:
        raise ValueError(
            "`y` must be 1-D, got an array with %d dimension(s)" % y.ndim)

    if threshold == 1:
        touch_spikes = np.zeros_like(y)
        touch_spikes[y == threshold] = 1
        return np.expand_dims(touch_spikes, axis=0)

    below = y < threshold
    at_or_above = y >= threshold
    # Compare every sample with its successor to detect the two kinds of crossing.
    rising = np.flatnonzero(below[:-1] & at_or_above[1:])
    falling = np.flatnonzero(at_or_above[:-1] & below[1:])

    onset_spikes = np.zeros_like(y)
    offset_spikes = np.zeros_like(y)
    onset_spikes[rising] = 1
    offset_spikes[falling] = 1
    return np.stack((onset_spikes, offset_spikes))
def get_data_dict(batch_size, type='train'):
    """
    Fetch the next MNIST batch and build the dictionary to be fed when
    running a tensorflow op.

    The pixel values of each image are repeated ``FLAGS.n_repeat`` times along
    the pixel axis when ``FLAGS.n_repeat > 1``. When ``FLAGS.crs_thr`` is set,
    every image is additionally converted to threshold crossing spikes and an
    output cue channel is appended; otherwise the pixel values are fed
    directly as a single input channel.

    :param batch_size: number of samples requested from the data set.
    :param type: data group to draw from: 'train', 'validation' or 'test'
        (the test batch is fetched with ``shuffle=False``).
    :return: tuple ``(data_dict, input_px)`` where ``data_dict`` maps
        ``input_spikes`` to the network input and ``targets`` to the integer
        class labels, and ``input_px`` holds the (possibly repeated) pixel
        values of the batch.
    :raises ValueError: if `type` is not one of the known data groups.
    """
    if type == 'test':
        input_px, target_oh = mnist.test.next_batch(batch_size, shuffle=False)
    elif type == 'validation':
        input_px, target_oh = mnist.validation.next_batch(batch_size)
    elif type == 'train':
        input_px, target_oh = mnist.train.next_batch(batch_size)
    else:
        raise ValueError("Wrong data group: " + str(type))

    # One-hot targets (batch x classes) -> integer class labels (batch).
    target_num = np.argmax(target_oh, axis=1)

    if FLAGS.n_repeat > 1:
        input_px = np.repeat(input_px, FLAGS.n_repeat, axis=1)

    if FLAGS.crs_thr:
        # GENERATE THRESHOLD CROSSING SPIKES
        # number of input neurons determines the resolution
        thrs = np.linspace(0, 1, FLAGS.n_in // 2)
        # For every image, stack the spike rows of all thresholds once and
        # transpose to time x channels.
        # presumably img is (784 * n_repeat,) and the stacked spikes are
        # (31, time) for n_in = 32 -- TODO confirm
        spike_stack = np.array([
            np.concatenate([find_onset_offset(img, thr) for thr in thrs]).T
            for img in input_px
        ])

        # add output cue neuron, and expand time for two image rows (2*28)
        out_cue_duration = 2 * 28 * FLAGS.n_repeat
        n_steps = spike_stack.shape[1]
        spike_stack = np.pad(
            spike_stack, ((0, 0), (0, out_cue_duration), (0, 1)), 'constant')
        # output cue neuron fires constantly for these additional recall steps
        # (indexing from the original length stays correct for a zero duration)
        spike_stack[:, n_steps:, -1] = 1
    else:
        # Feed the raw pixel values as one input channel: batch x time x 1.
        # No output cue neuron is added in this mode.
        spike_stack = np.expand_dims(input_px, axis=2)

    data_dict = {input_spikes: spike_stack, targets: target_num}
    return data_dict, input_px
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment