Skip to content

Instantly share code, notes, and snippets.

@proger
Last active April 22, 2024 13:36
Show Gist options
  • Save proger/8f68bbbef82dc51986039cd1ebc8baee to your computer and use it in GitHub Desktop.
Save proger/8f68bbbef82dc51986039cd1ebc8baee to your computer and use it in GitHub Desktop.
Tensor network that can learn XOR
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# The four corners of the unit square: every possible pair of binary inputs.
X = torch.tensor(
    [
        [0.0, 0.0],
        [0.0, 1.0],
        [1.0, 0.0],
        [1.0, 1.0],
    ],
    dtype=torch.float32,
)
# Target is 1.0 exactly when the two inputs differ (XOR), otherwise 0.0.
y = torch.logical_xor(X[:, 0], X[:, 1]).to(torch.float32)
# http://outlace.com/TensorNets1.html
class TensorNetwork(nn.Module):
    """Bias-free model whose only non-linearity is an element-wise product.

    The input is projected by two independent weight matrices, the two
    projections are multiplied element-wise, and the product is contracted
    with an output weight matrix.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of each of the two projections.
        out_features: Size of the last dimension of the output.

    The defaults ``(2, 2, 1)`` reproduce the sizes used for the XOR problem.
    """

    def __init__(
        self,
        in_features: int = 2,
        hidden_features: int = 2,
        out_features: int = 1,
    ) -> None:
        super().__init__()
        # Weights follow the F.linear layout (out, in) and start uniform in [0, 1).
        self.a = nn.Parameter(torch.rand(hidden_features, in_features))
        self.b = nn.Parameter(torch.rand(hidden_features, in_features))
        self.output = nn.Parameter(torch.rand(out_features, hidden_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``((x @ a.T) * (x @ b.T)) @ output.T``.

        Args:
            x: Tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_features``.
        """
        # Violate no-cloning: the same input is fed to both branches.
        xa = F.linear(x, self.a)
        xb = F.linear(x, self.b)
        # Collapse: contract the element-wise product with the output weights.
        y = F.linear(xa * xb, self.output)
        # Equivalent einsum formulation for a 2-D input (batch, in_features):
        #   xa = torch.einsum('bi,hi->bh', x, self.a)
        #   xb = torch.einsum('bi,hi->bh', x, self.b)
        #   y = torch.einsum('bh,bh,oh->bo', xa, xb, self.output)
        return y
class CopyMul(nn.Module):
    """Bias-free model that squares a single linear projection of its input.

    The input is projected by one weight matrix, the projection is multiplied
    element-wise with itself, and the result is contracted with an output
    weight matrix.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the projection.
        out_features: Size of the last dimension of the output.

    The defaults ``(2, 2, 1)`` reproduce the sizes used for the XOR problem.
    """

    def __init__(
        self,
        in_features: int = 2,
        hidden_features: int = 2,
        out_features: int = 1,
    ) -> None:
        super().__init__()
        # Weights follow the F.linear layout (out, in) and start uniform in [0, 1).
        self.a = nn.Parameter(torch.rand(hidden_features, in_features))
        self.output = nn.Parameter(torch.rand(out_features, hidden_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``((x @ a.T) ** 2) @ output.T``.

        Args:
            x: Tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_features``.
        """
        xa = F.linear(x, self.a)
        # The projection is multiplied with a copy of itself, i.e. squared.
        y = F.linear(xa * xa, self.output)
        return y
# Model under training; its parameters are handed to plain SGD.
model = TensorNetwork()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-4,
)
@torch.inference_mode()
def plot_decision_boundary(model, X, y):
    """Plot the region where ``model`` outputs more than 0.5, plus the data.

    The model is evaluated on a regular 2-D grid that starts at -0.1 with a
    spacing of 0.1 and extends slightly past 1 on both axes. The thresholded
    output is drawn as a filled contour and the points in ``X`` are scattered
    on top, coloured by ``y``.

    Args:
        model: Callable mapping an ``(N, 2)`` tensor to ``N`` output values.
        X: Two-column data; column 0 is drawn on the horizontal axis and
            column 1 on the vertical axis.
        y: Per-point values used as scatter colours.
    """
    x = torch.arange(-0.1, 1+0.1, 0.1)
    # 'ij' is the layout torch.meshgrid has always produced by default; passing
    # it explicitly avoids the deprecation warning and pins the behaviour.
    XX, YY = torch.meshgrid(x, x, indexing='ij')
    data = torch.hstack((XX.reshape(-1, 1), YY.reshape(-1, 1)))
    # The decorator already disables autograd, so no extra no_grad() is needed.
    out = model(data)
    Z = out.view(XX.shape)
    plt.figure(figsize=(3,3))
    plt.contourf(XX, YY, Z > 0.5, levels=1, alpha=0.5)
    plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', s=50)
    plt.show()
# Full-batch training on the four XOR points with an MSE objective.
for epoch in range(200000):
    # Drop the gradients of the previous step before computing new ones.
    optimizer.zero_grad()
    output = model(X).reshape(-1)  # (4, 1) -> (4,) so it lines up with y
    loss = F.mse_loss(output, y)
    loss.backward()
    optimizer.step()

    # Only report every 3000th epoch.
    if (epoch+1) % 3000 != 0:
        continue
    print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')
    # Uncomment to inspect the gradients at a logging step:
    #print([(n, p.grad) for n, p in model.named_parameters() if p.grad is not None])
    # NOTE(review): the source indentation was lost; the plot is assumed to be
    # drawn at every logging step rather than once after training -- confirm.
    plot_decision_boundary(model, X, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment