Created
April 14, 2020 07:14
-
-
Save miki998/f0f8d34b23d45692852ed8885e490153 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CNN model training
#
# Relies on names defined elsewhere in the file/notebook (not visible here):
# `model`, `optimizer`, `error` (the loss criterion), `num_epochs`,
# `train_loader`, `test_loader` and `torch`.
# NOTE(review): inputs are reshaped to (batch, 3, 16, 16, 16), i.e. the loaders
# are assumed to yield data holding 3 * 16 * 16 * 16 values per sample —
# confirm against the dataset preparation upstream.


def _as_volume(images):
    """Reshape a batch of samples into (batch, 3, 16, 16, 16) volumes.

    The batch dimension is inferred (-1) rather than hard-coded to 100, so a
    final partial batch from a loader no longer raises a shape error.
    """
    return images.view(-1, 3, 16, 16, 16)


def _evaluate_accuracy(model, loader):
    """Return the classification accuracy of *model* over *loader*, in percent.

    The model is put in eval mode (so dropout / batch-norm layers, if any, do
    not behave as in training or update running statistics from test data)
    and autograd is disabled for the pass. The model's previous train/eval
    mode is restored afterwards, even if evaluation raises.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                outputs = model(_as_volume(images))
                # Predicted class = index of the largest output per sample.
                predicted = torch.max(outputs, 1)[1]
                total += len(labels)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("test_loader yielded no samples")
    return 100.0 * correct / total


count = 0  # global iteration counter, running across all epochs
loss_list = []  # training loss (float) sampled every 50 iterations
iteration_list = []  # iteration number matching each entry of loss_list
accuracy_list = []  # test accuracy (float, percent) sampled every 50 iterations
for epoch in range(num_epochs):
    for images, labels in train_loader:
        train = _as_volume(images)
        # Clear gradients
        optimizer.zero_grad()
        # Forward propagation
        outputs = model(train)
        # Calculate softmax and cross entropy loss
        loss = error(outputs, labels)
        # Calculating gradients
        loss.backward()
        # Update parameters
        optimizer.step()
        count += 1
        if count % 50 == 0:
            accuracy = _evaluate_accuracy(model, test_loader)
            # Store plain Python numbers rather than tensors, so the lists do
            # not keep tensors (possibly on device) alive.
            loss_list.append(loss.item())
            iteration_list.append(count)
            accuracy_list.append(accuracy)
        if count % 500 == 0:
            # 500 is a multiple of 50, so `accuracy` was just refreshed above.
            print('Iteration: {} Loss: {} Accuracy: {} %'.format(count, loss.item(), accuracy))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment