This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Hyperparameters for training the CNN on FashionMNIST.
# (Original paste carried trailing "| |" scrape artifacts; removed.)
num_epochs = 8         # full passes over the training set
num_classes = 10       # FashionMNIST has 10 clothing categories
batch_size = 100       # samples per optimizer step
learning_rate = 0.001  # Adam step size
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Loading dataset.
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps
# them to [-1, 1]. FashionMNIST images are single-channel (the model's first
# conv layer uses in_channels=1), so the mean/std tuples must have ONE
# element each — the original 3-element tuples raise a shape error on
# 1-channel tensors in current torchvision.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.FashionMNIST(root='./data',
                                      train=True,
                                      download=True,
                                      transform=transform)
# NOTE(review): the test_dataset call was cut off in this paste; reconstructed
# below as the standard held-out split (train=False) — the test_loader further
# down requires test_dataset to exist. Confirm against the full file.
test_dataset = datasets.FashionMNIST(root='./data',
                                     train=False,
                                     download=True,
                                     transform=transform)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Loading dataset into dataloader.
# Batch both splits; shuffle only the training stream so each epoch sees a
# different ordering, while evaluation order stays deterministic.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Defining the network | |
# NOTE(review): this paste is garbled — "| |" cell artifacts and banner lines
# interleave the class body, and its indentation was lost in extraction; the
# class definition continues in fragments below. Code left byte-identical.
class CNNModel(nn.Module): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Constructor fragment: builds convolution block 1 (conv -> ReLU -> 2x2 pool).
# NOTE(review): truncated in this paste — forward() below uses self.cnn2, so
# the original __init__ must define at least a second conv block (and likely
# a classifier head); confirm against the full file.
def __init__(self): | |
super(CNNModel, self).__init__() | |
#Convolution 1 | |
# Conv2d(1 -> 16 channels, kernel 5, stride 1, padding 2): padding=2 with a
# 5x5 kernel preserves the input's spatial size.
self.cnn1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2) | |
self.relu1 = nn.ReLU() | |
#Max pool 1 | |
# 2x2 max pool halves each spatial dimension.
self.maxpool1 = nn.MaxPool2d(kernel_size=2) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Forward-pass fragment: conv1 -> ReLU -> maxpool -> conv2.
# NOTE(review): truncated in this paste — no return statement is visible; the
# remainder (second activation/pool, flatten, fully-connected output) was
# lost. Code left byte-identical.
def forward(self, x): | |
#Convolution 1 | |
out = self.cnn1(x) | |
out = self.relu1(out) | |
#Max pool 1 | |
out = self.maxpool1(out) | |
#Convolution 2 | |
# self.cnn2 is not defined in the visible __init__ fragment above — it must
# come from the portion of the constructor lost in the paste.
out = self.cnn2(out) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create instance of the model defined above.
model = CNNModel()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create instance of the loss.
# CrossEntropyLoss takes raw (unnormalized) logits plus integer class labels
# and applies log-softmax internally — so the model must NOT end in a softmax.
criterion = nn.CrossEntropyLoss()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create instance of the optimizer (Adam) over all model parameters;
# lr comes from the hyperparameter block at the top of the file.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Train the model | |
# Training-loop fragment.
# NOTE(review): truncated in this paste — the body ends right after
# zero_grad(); the forward pass, loss, backward() and optimizer.step() calls
# were lost. Code left byte-identical.
# NOTE(review): `iter` shadows the Python builtin of the same name — rename
# (e.g. to `step`) when editing the full file.
# NOTE(review): Variable() wrappers are no-ops since PyTorch 0.4 (Variable
# merged into Tensor); they can be dropped in a modernization pass.
iter = 0 | |
for epoch in range(num_epochs): | |
for i, (images, labels) in enumerate(train_loader): | |
images = Variable(images) | |
labels = Variable(labels) | |
#Clear the gradients | |
# zero_grad() must precede backward(): gradients accumulate by default.
optimizer.zero_grad() | |
OlderNewer