Skip to content

Instantly share code, notes, and snippets.

@amaarora
Created January 1, 2019 05:59
Show Gist options
  • Save amaarora/34d094f14c5af2645ac90449f5a43237 to your computer and use it in GitHub Desktop.
Save amaarora/34d094f14c5af2645ac90449f5a43237 to your computer and use it in GitHub Desktop.
def get_weights(*dims):
    """Return a trainable ``nn.Parameter`` of shape ``dims``.

    Entries are drawn from a standard normal and divided by the first
    dimension — a simple fan-in-style scaling of the initialisation.
    """
    init = torch.randn(dims) / dims[0]
    return nn.Parameter(init)
def softmax(x):
    """Row-wise softmax of a 2-D tensor ``x`` (probabilities along dim=1).

    Subtracts each row's maximum before exponentiating (the log-sum-exp
    trick). This is mathematically identical to ``exp(x)/exp(x).sum()`` but
    numerically stable: the original unshifted form overflows ``torch.exp``
    to ``inf`` (and then yields NaN) for logits around 90 and above.
    """
    shifted = x - x.max(dim=1, keepdim=True).values  # per-row max, broadcast back
    exps = torch.exp(shifted)
    return exps / exps.sum(dim=1, keepdim=True)
class LogReg(nn.Module):
    """Logistic-regression classifier: flattens input images to vectors and
    applies one linear layer, returning per-class log-probabilities.

    Weight shapes (28*28 inputs, 10 classes) suggest MNIST-style data, i.e.
    inputs of shape (batch, 28, 28) or (batch, 784) — confirm against caller.
    """

    def __init__(self):
        super().__init__()
        # Single linear layer: 784 inputs -> 10 class logits.
        self.l1_w = get_weights(28 * 28, 10)  # Layer 1 weights
        self.l1_b = get_weights(10)           # Layer 1 bias

    def forward(self, x):
        """Return (batch, 10) log-probabilities for a batch of inputs."""
        x = x.view(x.size(0), -1)        # flatten everything but the batch dim
        x = (x @ self.l1_w) + self.l1_b  # linear layer
        # torch.log_softmax is the numerically stable, fused equivalent of
        # torch.log(softmax(x)); the unfused original overflows torch.exp
        # for large logits and produces -inf/NaN log-probabilities.
        return torch.log_softmax(x, dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment