Skip to content

Instantly share code, notes, and snippets.

@braun-steven
Created April 3, 2020 16:29
Show Gist options
  • Save braun-steven/ceb899a64630cb1473e84986b0bfb3b5 to your computer and use it in GitHub Desktop.
Save braun-steven/ceb899a64630cb1473e84986b0bfb3b5 to your computer and use it in GitHub Desktop.
Layer-wise SPN Example Usage: Forward, Sampling, Conditional Sampling
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.datasets import make_blobs
from torch import nn
from tqdm import trange
from spn.algorithms.layerwise.layers import Product, Sum
from spn.algorithms.layerwise.utils import provide_evidence
from spn.experiments.RandomSPNs_layerwise.distributions import RatNormal
class LayerSpn(nn.Module):
    """Minimal layer-wise SPN over two features: Normal leaves -> Product -> Sum root.

    ``forward`` evaluates the network bottom-up; ``sample`` draws samples top-down.
    """

    def __init__(self):
        super().__init__()
        # Normal leaf layer, output shape: [N=?, D=2, C=2, R=1]
        self.leaf = RatNormal(in_features=2, out_channels=2)
        # Product layer combining both features, output shape: [N=?, D=1, C=2, R=1]
        self.p = Product(in_features=2, cardinality=2)
        # Sum layer, root node, output shape: [N=?, D=1, C=1, R=1]
        self.s = Sum(in_channels=2, in_features=1, out_channels=1)

    def forward(self, x):
        """Evaluate the SPN bottom-up (leaf -> product -> sum) on ``x``.

        Returns the root sum layer's output. Presumably these are log-likelihoods,
        since training minimizes their negated mean (TODO confirm against layer docs).
        """
        x = self.leaf(x)
        x = self.p(x)
        x = self.s(x)
        return x

    def sample(self, n=100):
        """Draw ``n`` samples top-down.

        The root sum layer produces a sampling context, which is passed through the
        product layer to the leaf layer; the leaf layer returns the samples.
        """
        ctx = self.s.sample(n=n)
        ctx = self.p.sample(context=ctx)
        samples = self.leaf.sample(context=ctx)
        return samples


def _make_blob_data(n_labels, n_samples, random_state=0):
    """Generate 2-D Gaussian blobs; return (float32 torch tensor of points, numpy label array)."""
    data, y = make_blobs(
        n_samples=n_samples, centers=n_labels, n_features=2, random_state=random_state, center_box=(-15, 15), cluster_std=0.5
    )
    return torch.from_numpy(data).float(), y


def _train(spn, data, batch_size=250, n_epochs=1000, lr=0.5):
    """Fit ``spn`` to ``data`` with plain SGD by minimizing the negated mean SPN output."""
    optimizer = torch.optim.SGD(spn.parameters(), lr=lr, weight_decay=0.0)
    n_total = data.shape[0]
    with trange(n_epochs) as epoch_iter:
        for _ in epoch_iter:
            running_loss = 0.0
            n_batches = 0
            for start in range(0, n_total, batch_size):
                batch = data[start : start + batch_size]
                # Reset gradients
                optimizer.zero_grad()
                # Inference
                output = spn(batch)
                # Compute loss
                loss = -1 * output.mean()
                # Backprop
                loss.backward()
                optimizer.step()
                # Collect loss
                running_loss += loss.item()
                n_batches += 1
            # Average over the batches actually processed, so a trailing partial batch is
            # counted and a dataset smaller than ``batch_size`` cannot divide by zero.
            epoch_iter.set_description(f"Loss: {running_loss / max(n_batches, 1):<3.4f}")


def _plot_training_data(data, y, n_labels):
    """Scatter the training data per blob in subplot 1; return its (xlim, ylim) for reuse."""
    plt.subplot(2, 2, 1)
    plt.title("Original training data")
    for i in range(n_labels):
        plt.scatter(*data[y == i].T, label=f"Blob {i}", alpha=0.7)
    plt.legend()
    plt.xlabel("$x_0$")
    plt.ylabel("$x_1$")
    return plt.xlim(), plt.ylim()


def _plot_unconditional_samples(spn, xlim, ylim, n=1000):
    """Draw ``n`` unconditional samples from ``spn`` and scatter them in subplot 2."""
    samples = spn.sample(n=n)
    plt.subplot(2, 2, 2)
    plt.xlabel("$x_0$")
    plt.ylabel("$x_1$")
    plt.title("Unconditioned Samples")
    plt.scatter(*samples.T, alpha=0.7, c="black")
    plt.xlim(xlim)
    plt.ylim(ylim)


def _plot_conditional_samples(spn, data, y, n_labels, observed, position, xlim, ylim):
    """Plot samples conditioned on feature ``observed`` (0 or 1) in subplot ``position``.

    For each blob, the other feature is set to NaN, and the result is handed to the SPN
    via ``provide_evidence`` before sampling one point per data row.
    """
    missing = 1 - observed
    plt.subplot(2, 2, position)
    plt.xlabel("$x_0$")
    plt.ylabel("$x_1$")
    plt.title(f"Conditioned Samples (on $x_{observed}$)")
    for i in range(n_labels):
        # Boolean-mask indexing returns a copy, so writing NaNs leaves ``data`` intact.
        evidence = data[y == i]
        evidence[:, missing] = float("nan")
        with provide_evidence(spn, evidence):
            samples = spn.sample(n=evidence.shape[0])
        plt.scatter(*samples.T, label=f"Blob {i}", alpha=0.7)
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.legend()


def main(seed=0):
    """Run the demo end to end and save the figure to ``sampling-result.png``.

    The demo fits an SPN to two Gaussian blobs. It then plots the training data,
    unconditional samples, and samples conditioned on each feature.
    ``seed`` seeds torch so model initialization and sampling are reproducible.
    """
    torch.manual_seed(seed)

    # Generate two gaussian blobs
    n_labels = 2
    n_samples = 500
    data, y = _make_blob_data(n_labels, n_samples)

    # Plot the original data; reuse its axis limits for all sample plots
    plt.figure()
    xlim, ylim = _plot_training_data(data, y, n_labels)

    # Create and fit the SPN model with SGD
    spn = LayerSpn()
    _train(spn, data, batch_size=250, n_epochs=1000, lr=0.5)

    with torch.no_grad():
        _plot_unconditional_samples(spn, xlim, ylim, n=1000)
        _plot_conditional_samples(spn, data, y, n_labels, observed=0, position=3, xlim=xlim, ylim=ylim)
        _plot_conditional_samples(spn, data, y, n_labels, observed=1, position=4, xlim=xlim, ylim=ylim)

    plt.tight_layout()
    plt.savefig("sampling-result.png", dpi=180)


if __name__ == "__main__":
    main()
@braun-steven
Copy link
Author

braun-steven commented Apr 3, 2020

sampling-result

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment