import torch


def __init__(self, train_data_dir, batch_size=128, test_data_dir=None, num_workers=4):
    '''Constructor method

    Parameters:
    train_data_dir (string): path of the training dataset, used for both training and validation
    batch_size (int): number of images per batch. Defaults to 128.
    test_data_dir (string): path of the testing dataset, used after training. Optional.
    num_workers (int): number of processes used by the data loader. Defaults to 4.
    '''
    # Invoke the parent constructor
    super(MNISTClassifier, self).__init__()
    # Set up class attributes
    self.batch_size = batch_size
    self.train_data_dir = train_data_dir
    self.test_data_dir = test_data_dir
    self.num_workers = num_workers
    # Define network layers as class attributes
    self.conv_layer_1 = torch.nn.Sequential(
        # The first block is made of a convolutional layer
        # (3 input channels, 28 output channels, 5x5 kernel),
        torch.nn.Conv2d(3, 28, kernel_size=5),
        # a non-linear activation function,
        torch.nn.ReLU(),
        # and a max pooling layer with a 2x2 window
        torch.nn.MaxPool2d(kernel_size=2))
    # The second block mirrors the first, except for the channel and kernel sizes
    self.conv_layer_2 = torch.nn.Sequential(
        torch.nn.Conv2d(28, 10, kernel_size=2),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(kernel_size=2))
    # A dropout layer, useful to reduce network overfitting
    self.dropout1 = torch.nn.Dropout(0.25)
    # A fully connected layer to reduce dimensionality
    self.fully_connected_1 = torch.nn.Linear(250, 18)
    # A second, lighter dropout layer for further regularization
    self.dropout2 = torch.nn.Dropout(0.08)
    # The final fully connected layer, whose output maps to the number of desired classes
    self.fully_connected_2 = torch.nn.Linear(18, 10)
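

# --- Not part of the original gist: a minimal sketch of the forward() method
# that would compose the layers defined above, added to make the Linear(250, 18)
# input size explicit. Assuming 3x28x28 inputs, the shapes flow as:
# Conv2d(3, 28, k=5) -> 28x24x24; MaxPool2d(2) -> 28x12x12;
# Conv2d(28, 10, k=2) -> 10x11x11; MaxPool2d(2) -> 10x5x5 = 250 features.
def forward(self, x):
    x = self.conv_layer_1(x)
    x = self.conv_layer_2(x)
    x = self.dropout1(x)
    # Flatten each feature map to a (batch_size, 250) tensor
    x = torch.flatten(x, start_dim=1)
    x = torch.nn.functional.relu(self.fully_connected_1(x))
    x = self.dropout2(x)
    x = self.fully_connected_2(x)
    # Log-probabilities over the 10 MNIST classes
    return torch.nn.functional.log_softmax(x, dim=1)


# --- Also an assumption, not in the original gist: a hedged sketch of the
# train_dataloader() hook that the constructor attributes (train_data_dir,
# batch_size, num_workers) are meant to feed. torchvision's MNIST dataset is
# an assumption here; Grayscale(num_output_channels=3) replicates the single
# MNIST channel to match the 3 input channels expected by conv_layer_1.
import torchvision


def train_dataloader(self):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=3),
        torchvision.transforms.ToTensor()])
    dataset = torchvision.datasets.MNIST(
        self.train_data_dir, train=True, download=True, transform=transform)
    return torch.utils.data.DataLoader(
        dataset, batch_size=self.batch_size,
        num_workers=self.num_workers, shuffle=True)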