# LeNet-5 style convolutional network in PyTorch.
# Originally shared by @seanbenhur as GitHub gist
# 909f5ac77489c62042479a7c304bdb5f (created November 7, 2020).
from torch import nn


class Lenet(nn.Module):
    """LeNet-5 style convolutional classifier (tanh activations, average pooling).

    Architecture:
        conv1 5x5 -> 6 maps  -> tanh -> 2x2 avg-pool
        conv2 5x5 -> 16 maps -> tanh -> 2x2 avg-pool
        conv3 5x5 -> 120 maps -> tanh -> flatten
        linear1 120 -> 84 -> tanh
        linear2 84 -> ``num_classes`` (raw scores, no softmax)

    ``linear1`` expects exactly 120 input features, so conv3 must produce a
    1x1 map. That holds for square inputs of 32x32 (up to 35x35). A 28x28
    input (e.g. unpadded MNIST) shrinks to 4x4 before conv3 and fails there,
    so pad such images to 32x32 first.

    Args:
        in_channels: Number of channels in the input images (default 1).
        num_classes: Number of output scores per sample (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        # Registration order is kept identical to the original so that
        # state_dict keys (conv1.*, conv2.*, conv3.*, linear1.*, linear2.*)
        # remain compatible with previously saved checkpoints.
        self.tanh = nn.Tanh()
        self.pool = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=6, kernel_size=(5, 5), stride=(1, 1)
        )
        self.conv2 = nn.Conv2d(
            in_channels=6, out_channels=16, kernel_size=(5, 5), stride=(1, 1)
        )
        self.conv3 = nn.Conv2d(
            in_channels=16, out_channels=120, kernel_size=(5, 5), stride=(1, 1)
        )
        self.linear1 = nn.Linear(120, 84)
        self.linear2 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a batch of images to unnormalized class scores.

        Args:
            x: Tensor of shape (batch, in_channels, H, W); see the class
                docstring for the admissible H and W (e.g. 32x32).

        Returns:
            Tensor of shape (batch, num_classes) holding raw scores (logits).
        """
        x = self.pool(self.tanh(self.conv1(x)))
        x = self.pool(self.tanh(self.conv2(x)))
        x = self.tanh(self.conv3(x))
        # conv3 output is (batch, 120, 1, 1); flatten everything but the batch
        # dim. Unlike reshape(x.shape[0], -1), this also works for batch == 0.
        x = x.flatten(1)
        x = self.tanh(self.linear1(x))
        return self.linear2(x)


model = Lenet()