Skip to content

Instantly share code, notes, and snippets.

@FrancescoSaverioZuppichini
Created September 23, 2018 16:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FrancescoSaverioZuppichini/6a4db72b27d5d6aa2519ef9917b3c821 to your computer and use it in GitHub Desktop.
Save FrancescoSaverioZuppichini/6a4db72b27d5d6aa2519ef9917b3c821 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyCNNClassifier(nn.Module):
    """Small CNN: two conv-BN-ReLU stages followed by a two-layer linear head.

    Both convolutions use a 3x3 kernel with stride 1 and padding 1, so the
    spatial size of the input is preserved through the feature extractor.
    The flattened feature vector fed to ``fc1`` therefore has
    ``64 * H * W`` elements (64 = output channels of ``conv2``).

    Args:
        in_c: Number of input channels.
        n_classes: Number of output units produced by ``fc2``.
        img_size: Spatial size of the input, either an ``int`` for square
            inputs or an ``(H, W)`` pair. Defaults to 28, the size the
            layer widths were originally hard-coded for.
    """

    def __init__(self, in_c, n_classes, img_size=28):
        super().__init__()
        if isinstance(img_size, int):
            height, width = img_size, img_size
        else:
            height, width = img_size
        self.conv1 = nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Must match conv2's 64 output channels (previously 32, which crashed).
        self.bn2 = nn.BatchNorm2d(64)
        # conv2 emits 64 channels at the unchanged spatial size, so the
        # flattened width is 64 * H * W (previously 32 * 28 * 28, a mismatch).
        self.fc1 = nn.Linear(64 * height * width, 1024)
        self.fc2 = nn.Linear(1024, n_classes)

    def forward(self, x):
        """Return unnormalized class scores of shape ``(batch, n_classes)``.

        Args:
            x: Input tensor of shape ``(batch, in_c, H, W)``, where ``(H, W)``
                matches the ``img_size`` given at construction.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        # Keep the batch dim, flatten the rest: (batch, 64 * H * W).
        x = torch.flatten(x, 1)
        # Sigmoid hidden activation is kept from the original design;
        # torch.sigmoid replaces the deprecated F.sigmoid.
        x = torch.sigmoid(self.fc1(x))
        # No softmax: raw scores are returned.
        return self.fc2(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment