Skip to content

Instantly share code, notes, and snippets.

@zubair-irshad
Last active February 19, 2020 01:19
Show Gist options
  • Save zubair-irshad/a4184bae7f748d1e9b4be28878f6f1d0 to your computer and use it in GitHub Desktop.
Save zubair-irshad/a4184bae7f748d1e9b4be28878f6f1d0 to your computer and use it in GitHub Desktop.
class CNN(nn.Module):
    """Four-stage convolutional classifier with a three-layer fully connected head.

    Each stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, 2)``, with channel widths 3 -> 16 -> 32 -> 64 -> 128.  The
    3x3/padding-1 convolutions keep the spatial size, and every pooling step
    floor-halves it, so the feature map entering the head is
    ``128 x (H // 16) x (W // 16)``.
    """

    # Number of conv+pool stages; each one halves both spatial dimensions.
    _NUM_POOL_STAGES = 4

    def __init__(self, im_size, hidden_dim, hidden_dim2, hidden_dim3, kernel_size, n_classes):
        """Build the layers.

        Args:
            im_size: 3-tuple describing one input image; presumably
                ``(channels, height, width)``.  Only the two spatial entries
                are used, to size the first fully connected layer.  The
                channel entry is ignored: the first convolution is hard-wired
                to 3 input channels.
            hidden_dim: Output width of the first fully connected layer.
            hidden_dim2: Output width of the second fully connected layer.
            hidden_dim3: Unused; kept so existing callers keep working.
            kernel_size: Unused; all convolutions use a fixed 3x3 kernel.
            n_classes: Number of output scores per image.

        Raises:
            ValueError: If either spatial dimension is smaller than 16, since
                four rounds of 2x2 pooling would leave an empty feature map.
        """
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)

        _, height, width = im_size
        # Floor division mirrors MaxPool2d's default (ceil_mode=False) output
        # size and, unlike `/`, yields the int that nn.Linear requires.
        shrink = 2 ** self._NUM_POOL_STAGES
        hout = int(height) // shrink
        wout = int(width) // shrink
        if hout < 1 or wout < 1:
            raise ValueError(
                "im_size spatial dimensions must be at least %d, got %r"
                % (shrink, (height, width))
            )

        self.fc1 = nn.Linear(hout * wout * 128, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, n_classes)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.2)

    def forward(self, images):
        """Compute class scores for a batch of images.

        Args:
            images: Batched input with 3 channels (required by ``conv1``) and
                the leading dimension as the batch axis.  The spatial size is
                expected to match the ``im_size`` given at construction,
                otherwise the flattened features will not fit ``fc1``.

        Returns:
            Tensor of raw (un-normalized) scores with one row per image and
            ``n_classes`` columns.
        """
        x = self.pool(F.relu(self.bn1(self.conv1(images))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        # Flatten the output to one feature vector per image.
        x = x.view(x.shape[0], -1)
        x = self.dropout1(x)
        x = F.relu(self.fc1(x))
        # The same p=0.25 dropout module is reused after fc1.
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        scores = self.fc3(x)
        return scores
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment