Skip to content

Instantly share code, notes, and snippets.

@neworderofjamie
Created March 16, 2022 12:58
Show Gist options
  • Save neworderofjamie/2a276e5386fd6b17893ad31965d0dbe7 to your computer and use it in GitHub Desktop.
Save neworderofjamie/2a276e5386fd6b17893ad31965d0dbe7 to your computer and use it in GitHub Desktop.
import numpy as np
from os import path
from ml_genn import InputLayer, Layer, SequentialModel
from ml_genn.compilers import Compiler
from ml_genn.neurons import IntegrateFire, IntegrateFireInput
from ml_genn.connectivity import Dense
# Load weights: weights_0_1.npy, weights_1_2.npy, ... up to the first missing file.
def load_weights(directory="."):
    """Return the consecutive weight matrices stored as ``weights_<i>_<i+1>.npy``.

    Files are read in order starting from ``weights_0_1.npy``; loading stops at
    the first index whose file does not exist, so the result is an empty list
    when ``weights_0_1.npy`` is absent.

    Args:
        directory: folder containing the ``.npy`` files (defaults to the
            current working directory, matching the script's original lookup).

    Returns:
        List of arrays returned by ``np.load``, in layer order.
    """
    import itertools

    loaded = []
    for i in itertools.count():
        filename = path.join(directory, f"weights_{i}_{i + 1}.npy")
        # EAFP: attempt the load directly instead of checking existence first,
        # which avoids a check-then-use race and a second filesystem lookup.
        try:
            matrix = np.load(filename)
        except FileNotFoundError:
            break
        loaded.append(matrix)
    return loaded


weights = load_weights()
# Create sequential model: an input layer followed by one Dense/IntegrateFire
# layer per loaded weight matrix.
def build_model(weights, v_thresh=5.0, num_inputs=784):
    """Create a SequentialModel from a list of dense weight matrices.

    Args:
        weights: weight matrices, one per non-input layer, in feed-forward order.
        v_thresh: firing threshold shared by the input layer and every
            IntegrateFire layer.
        num_inputs: size of the input layer (784 by default; presumably the
            28x28 MNIST pixels given the "simple_mnist" name — TODO confirm).

    Returns:
        ``(model, input_layer)``.
    """
    model = SequentialModel()
    with model:
        input_layer = InputLayer(IntegrateFireInput(v_thresh=v_thresh), num_inputs)
        # The Layer objects are not stored here; they are presumably registered
        # with ``model`` by the ``with model:`` context and read back later via
        # ``model.layers``.
        for w in weights:
            Layer(Dense(weight=w), IntegrateFire(v_thresh=v_thresh))
    return model, input_layer


def evaluate(compiled_model, input_layer, output_layer, images, labels,
             num_timesteps=100, input_scale=0.01):
    """Return the classification accuracy of ``compiled_model`` in percent.

    For every image the trial is reset, ``image * input_scale`` is set as the
    input of ``input_layer``, and the simulation is stepped ``num_timesteps``
    times. The predicted class is the argmax of ``output_layer``'s output.
    Call this inside ``with compiled_model:``.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length (``zip``
            would otherwise silently drop the extras and skew the accuracy),
            or if there are no images (the accuracy would divide by zero).
    """
    if len(images) != len(labels):
        raise ValueError(
            f"{len(images)} testing images but {len(labels)} testing labels")
    if len(images) == 0:
        raise ValueError("No testing images to evaluate")

    num_correct = 0
    for img, lab in zip(images, labels):
        # TODO(original author): "handle weak ref" — context not visible here.
        compiled_model.reset_trial()
        compiled_model.set_input({input_layer: img * input_scale})
        for _ in range(num_timesteps):
            compiled_model.step_time()
        output = compiled_model.get_output(output_layer)
        if np.argmax(output) == lab:
            num_correct += 1
    return (num_correct / len(images)) * 100.0


if __name__ == "__main__":
    if not weights:
        # Without any weight files there is no Dense layer to read
        # predictions from, so any reported accuracy would be meaningless.
        raise FileNotFoundError("No weights_<i>_<i+1>.npy files found")

    model, input_layer = build_model(weights)
    compiler = Compiler(dt=1.0)
    compiled_model = compiler.compile(model, "simple_mnist")

    # Load testing data
    testing_images = np.load("testing_images.npy")
    testing_labels = np.load("testing_labels.npy")

    with compiled_model:
        accuracy = evaluate(compiled_model, input_layer, model.layers[-1],
                            testing_images, testing_labels)
    print(f"Accuracy {accuracy}%")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment