Minimal norse learning example, showing one possible (and definitely not the most efficient) way of learning the XOR problem in norse. Created while following the norse workshop during HBPSC23.
import torch
import numpy as np
import norse.torch as norse
import matplotlib.pyplot as plt
import tqdm.notebook as tqdm

_ = torch.manual_seed(0)
class MyModule(torch.nn.Module):
    def __init__(self, p=norse.LIFParameters()):
        super().__init__()
        # Per-neuron membrane time constants as learnable parameters.
        # Note: p, p1 and p2 are constructed but never passed to the
        # LIFCells below, so the cells run with default parameters.
        p1 = norse.LIFParameters(
            tau_mem_inv=torch.nn.Parameter(
                torch.full((3,), torch.as_tensor(1.0 / 1e-2))
            )
        )
        p2 = norse.LIFParameters(
            tau_mem_inv=torch.nn.Parameter(
                torch.full((5,), torch.as_tensor(1.0 / 1e-2))
            ),
        )
        self.model = norse.SequentialState(
            torch.nn.Linear(3, 8, bias=False),
            norse.LIFCell(),
            torch.nn.Linear(8, 3, bias=False),
            norse.LIFCell(),
            torch.nn.Linear(3, 2, bias=False),
            # Leaky-integrator readout with a slow membrane (tau = 0.1 s)
            norse.LICell(norse.LIParameters(tau_mem_inv=torch.as_tensor(1 / 0.1))),
            torch.nn.Softmax(dim=-1),  # explicit dim avoids the deprecation warning
        )
        for l in self.model.children():
            if isinstance(l, torch.nn.Linear):
                w = l.get_parameter('weight')
                torch.nn.init.uniform_(w)

    def forward(self, x, state):
        return self.model(x, state)
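# A minimal sanity-check sketch (not part of the original training flow):
# norse.SequentialState threads per-layer state through the network, and
# passing state=None initializes every stateful layer. Left commented out
# so it does not consume the seeded RNG before the real model is built.
# demo = MyModule()
# out, st = demo(torch.tensor([[1., 0., 1.]]), None)
# print(out.shape)  # torch.Size([1, 2]): softmax over the two classes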
plt.ion()  # interactive plotting: update the trace figure during training

def train(model, xs, ys):
    losses = []
    for x, y in tqdm.tqdm(zip(xs, ys), total=len(xs)):
        state = None
        trace = []
        l = 0
        # Optionally add multiplicative noise to the inputs:
        # x = (x * (0.5 + 0.5 * torch.randn_like(x)) + (1 - (1 - x) * (0.5 + 0.5 * torch.randn_like(x)))) / 2
        # Simulate 100 timesteps; accumulate the loss only after t = 50,
        # once the membrane dynamics have had time to settle.
        for t in range(100):
            out, state = model(x, state)
            trace.append(out.detach().numpy())
            if t > 50:
                loss = torch.nn.functional.mse_loss(out, y)
                l = l + loss
        trace = np.array(trace)
        plt.clf()
        plt.plot(trace[:, :, 0])
        plt.ylim([0, 1])
        plt.pause(0.01)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        losses.append(l.detach())  # accumulated loss, not just the last timestep's
        if not plt.get_fignums():
            break  # stop training when the figure window is closed
    return losses
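# A possibly more efficient alternative (a sketch, assuming norse.torch.Lift,
# norse.torch.LIF and norse.torch.LI process (time, batch, features) tensors
# as documented): feed the whole sequence at once instead of stepping cells
# in a Python loop. Commented out; not used below.
# seq_model = norse.SequentialState(
#     norse.Lift(torch.nn.Linear(3, 8, bias=False)),
#     norse.LIF(),
#     norse.Lift(torch.nn.Linear(8, 2, bias=False)),
#     norse.LI(),
# )
# xt = torch.tensor([[1., 0., 1.]]).expand(100, 1, 3)  # (time, batch, features)
# out_seq, _ = seq_model(xt)  # out_seq: (100, 1, 2)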
nsamples = 1000

# Inputs: a constant bias term plus the two XOR operands, i.e. [1, a, b].
data = torch.as_tensor([
    [
        [1., 0., 0.],
        [1., 1., 0.],
        [1., 0., 1.],
        [1., 1., 1.],
    ]
] * nsamples)

# One-hot targets: [1, 0] when a XOR b is true, [0, 1] when it is false.
labels = torch.as_tensor([
    [
        [0., 1.], [1., 0.], [1., 0.], [0., 1.]
    ]
] * nsamples)

print(data.shape, labels.shape)  # torch.Size([1000, 4, 3]) torch.Size([1000, 4, 2])
model = MyModule()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
m1_losses = train(model, data, labels)

plt.clf()
plt.plot(m1_losses)
plt.ioff()
plt.show()
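# Evaluation sketch (assumes training has converged): run the trained model
# on the four XOR patterns for 100 timesteps and read out the class with the
# highest softmax output at the final step.
with torch.no_grad():
    state = None
    for t in range(100):
        out, state = model(data[0], state)
    pred = out.argmax(dim=-1)            # index 0 -> XOR true, 1 -> XOR false
    target = labels[0].argmax(dim=-1)
    print('predictions:', pred.tolist(), 'targets:', target.tolist())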