Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Created September 25, 2023 14:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save llandsmeer/67cd64855895544861dace89b939f99a to your computer and use it in GitHub Desktop.
Save llandsmeer/67cd64855895544861dace89b939f99a to your computer and use it in GitHub Desktop.
Minimal norse learning example - showing one possible (definitely not the most efficient) way of learning the XOR problem in norse. Created while following the norse workshop during HBPSC23
import torch
import numpy as np
import norse.torch as norse
import matplotlib.pyplot as plt
import tqdm.notebook as tqdm
# Seed torch's global RNG so the random weight initialisation in MyModule is
# reproducible between runs. The return value is bound to `_` only to discard it.
_ = torch.manual_seed(0)
class MyModule(torch.nn.Module):
    """Small spiking network used for the XOR example.

    The wrapped ``norse.SequentialState`` stacks three bias-free linear
    layers of sizes 3 -> 8 -> 3 -> 2. The first two are each followed by a
    ``norse.LIFCell`` built with its default parameters; the last one is
    followed by a ``norse.LICell`` (``tau_mem_inv = 1 / 0.1``) and a softmax
    over the last dimension.
    """

    def __init__(self, p=norse.LIFParameters()):
        """Build the network and re-initialise every linear weight.

        Args:
            p: Accepted for backward compatibility but currently unused; both
                LIF cells are constructed with their own default parameters.
                NOTE(review): confirm whether ``p`` was meant to be passed to
                the LIF cells.
        """
        super().__init__()
        self.model = norse.SequentialState(
            torch.nn.Linear(3, 8, bias=False),
            norse.LIFCell(),
            torch.nn.Linear(8, 3, bias=False),
            norse.LIFCell(),
            torch.nn.Linear(3, 2, bias=False),
            norse.LICell(norse.LIParameters(tau_mem_inv=torch.as_tensor(1 / 0.1))),
            # Explicit dim: without it PyTorch warns and guesses the axis from
            # the input rank. -1 matches that guess for 1-D and 2-D inputs and
            # always normalises over the two output channels.
            torch.nn.Softmax(dim=-1),
        )
        # Replace the default Linear initialisation with uniform_'s default
        # range (U(0, 1)), so every weight starts non-negative.
        for layer in self.model.children():
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.uniform_(layer.weight)

    def forward(self, x, state):
        """Advance the network by one step.

        Args:
            x: Input for this step, forwarded unchanged to the wrapped model.
            state: State from the previous step; ``train`` passes ``None``
                on the first step of each sample.

        Returns:
            Whatever the wrapped ``SequentialState`` returns; ``train``
            unpacks it as ``(output, new_state)``.
        """
        return self.model(x, state)
# Turn on matplotlib's interactive mode: train() clears and redraws the figure
# after every sample and relies on plt.pause() to refresh it.
plt.ion()
def train(model, xs, ys, *, n_steps=100, burn_in=50, opt=None):
    """Train ``model`` online, taking one optimizer step per ``(x, y)`` pair.

    For every pair the network state is reset to ``None``, the same input
    ``x`` is presented for ``n_steps`` steps, and the MSE between the output
    and ``y`` is summed over all steps whose index is greater than
    ``burn_in``. The summed loss is back-propagated and the optimizer is
    stepped once. After each sample the first output channel over time is
    plotted; if no matplotlib figure is open any more (e.g. the window was
    closed), training stops early.

    Args:
        model: Callable taking ``(x, state)`` and returning ``(out, state)``.
        xs: Iterable of inputs, paired element-wise with ``ys`` via ``zip``.
        ys: Iterable of targets compared to the model output with MSE.
        n_steps: Number of steps each sample is presented for.
        burn_in: Steps with index ``<= burn_in`` do not contribute to the loss.
        opt: Optimizer to use. When ``None`` the module-level ``optimizer``
            is used, which is what the original script relied on.

    Returns:
        A list with one detached loss tensor per processed sample. Each entry
        is the loss of the *last* step only, not the summed loss that was
        back-propagated.

    Raises:
        ValueError: If ``n_steps`` and ``burn_in`` leave no step to
            accumulate a loss from.
    """
    if n_steps <= burn_in + 1:
        raise ValueError(
            "n_steps must exceed burn_in + 1 so at least one step contributes to the loss"
        )
    if opt is None:
        # Backward-compatible fallback to the script's global optimizer.
        opt = optimizer

    # zip() has no length, so give tqdm the total explicitly when it is known.
    try:
        total = min(len(xs), len(ys))
    except TypeError:
        total = None

    losses = []
    for x, y in tqdm.tqdm(zip(xs, ys), total=total):
        state = None  # fresh network state for every sample
        trace = []
        summed_loss = 0
        for t in range(n_steps):
            out, state = model(x, state)
            # Detached copy for plotting only; .numpy() assumes a CPU tensor.
            trace.append(out.detach().numpy())
            if t > burn_in:
                loss = torch.nn.functional.mse_loss(out, y)
                summed_loss = summed_loss + loss

        # Live plot of output channel 0 over time, one curve per row of `out`.
        trace = np.array(trace)
        plt.clf()
        plt.plot(trace[:, :, 0])
        plt.ylim([0, 1])
        plt.pause(0.01)

        opt.zero_grad()
        summed_loss.backward()
        opt.step()
        # Record the final-step loss (see Returns).
        losses.append(loss.detach())

        # No open figures left: stop training.
        if not plt.get_fignums():
            break
    return losses
# Number of training samples; every sample is the same batch of four patterns.
nsamples = 1000

# The four XOR input patterns. Column 0 is a constant 1 (presumably a bias
# input, since the linear layers are built with bias=False); columns 1 and 2
# are the two XOR operands.
data = torch.as_tensor(
    [
        [1, 0., 0.],
        [1, 1., 0.],
        [1, 0., 1.],
        [1, 1., 1.],
    ]
).repeat(nsamples, 1, 1)

# Two-channel targets: channel 0 is 1 exactly when the operands differ,
# channel 1 is its complement.
labels = torch.as_tensor(
    [[0., 1.0], [1., 0.0], [1., 0.0], [0., 1.0]]
).repeat(nsamples, 1, 1)

print(data.shape, labels.shape)

model = MyModule()
# `optimizer` must stay a module-level name: train() looks it up as a global.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
m1_losses = train(model, data, labels)

# Replace the live output trace with the per-sample loss curve, then leave
# interactive mode so show() blocks until the window is closed.
plt.clf()
plt.plot(m1_losses)
plt.ioff()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment