Skip to content

Instantly share code, notes, and snippets.

@lgray
Last active March 5, 2020 14:51
Show Gist options
Save lgray/ea9b7f2fca609574b08a7f480ebfd105 to your computer and use it in GitHub Desktop.
# Here's a dynamic reduction network that categorizes its input into one of two classes (it outputs log-probabilities).
class Net(nn.Module):
    """Two-class classifier built on a DynamicReductionNetwork.

    The wrapped network produces ``output_dim=2`` raw scores per input;
    ``forward`` turns them into log-probabilities so the output can be fed
    directly to ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): ``norm`` presumably rescales the 3 input features
        # (the last two divided by 27) before the network sees them —
        # confirm against DynamicReductionNetwork and the input ranges.
        self.drn = DynamicReductionNetwork(
            input_dim=3,
            hidden_dim=64,
            k=16,
            output_dim=2,
            aggr='add',
            norm=torch.tensor([1., 1. / 27., 1. / 27.]),
        )

    def forward(self, data):
        """Return log-softmax class scores for ``data``.

        ``data`` is passed unchanged to the DynamicReductionNetwork. The
        log-softmax is taken over dim 1, which assumes the network returns
        logits shaped (num_items, 2) — TODO confirm.
        """
        logits = self.drn(data)
        return F.log_softmax(logits, dim=1)
# here's the training setup
model = Net().to(device)
# The optimizer must be built after the model so that it holds this model's
# parameters. NOTE(review): Adam with lr=1e-3 is a placeholder choice — tune.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()
for datum in data:
    datum = datum.to(device)
    # backward() accumulates into .grad, so clear the previous step's
    # gradients before computing new ones.
    optimizer.zero_grad()
    result = model(datum)
    # nll_loss expects class-index targets, not the input object itself.
    # Assumes the labels live in datum.y (torch_geometric Data convention) —
    # confirm against the dataset.
    loss = F.nll_loss(result, datum.y)
    loss.backward()
    optimizer.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment