@vaibhavkumar049
Created June 8, 2019 18:58
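```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Expects `fn` (a torch.nn.Module), `X_train`, `Y_train` (class-index labels)
# and an `accuracy(y_hat, y)` helper to be defined elsewhere.
def fit(epochs=1000, learning_rate=1):
    loss_arr = []
    acc_arr = []
    for epoch in range(epochs):
        # Forward pass and loss (F.cross_entropy expects raw, unnormalized logits)
        y_hat = fn(X_train)
        loss = F.cross_entropy(y_hat, Y_train)
        loss_arr.append(loss.item())
        acc_arr.append(accuracy(y_hat, Y_train))

        # Backward pass, then a manual gradient-descent step on every parameter
        loss.backward()
        with torch.no_grad():
            for param in fn.parameters():
                param -= learning_rate * param.grad
            fn.zero_grad()  # .grad accumulates across backward() calls, so reset it

    plt.plot(loss_arr, 'r-')  # loss curve in red
    plt.plot(acc_arr, 'b-')   # accuracy curve in blue
    plt.show()
    print('Loss before training', loss_arr[0])
    print('Loss after training', loss_arr[-1])
```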
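The loop relies on `F.cross_entropy` and on an `accuracy` helper that the gist does not define. The sketch below spells out, in plain Python with no PyTorch, what each epoch measures: cross-entropy as the mean of `-log softmax(logits)[target]` (the same quantity `F.cross_entropy` computes with its default mean reduction), and a hypothetical `accuracy` that counts how often the highest logit matches the label. The function bodies and the sample data are illustrative assumptions, not the author's code.

```python
import math

def cross_entropy(logits, targets):
    """Mean of -log softmax(row)[target] over all rows (raw logits in)."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the max before exp() for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total += log_sum_exp - row[t]  # equals -log(softmax(row)[t])
    return total / len(targets)

def accuracy(logits, targets):
    """Hypothetical helper: fraction of rows whose argmax equals the label."""
    correct = sum(
        1 for row, t in zip(logits, targets)
        if max(range(len(row)), key=row.__getitem__) == t
    )
    return correct / len(targets)

logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 1]
print(cross_entropy(logits, targets))
print(accuracy(logits, targets))  # → 0.5
```

With two equal logits the loss is `log 2`, the value for a model that is maximally unsure between two classes; it falls toward 0 as the correct logit comes to dominate.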
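The update step `param -= learning_rate * param.grad` followed by `fn.zero_grad()` is plain gradient descent. PyTorch adds each `backward()` result into `.grad` instead of overwriting it, which is why the reset matters. Below is a minimal sketch of the same pattern on a single scalar weight `w`, fitting `y = w * x` by mean squared error with a hand-derived gradient standing in for autograd. The function name `fit_scalar`, the loss, and the data are assumptions chosen for illustration.

```python
def fit_scalar(xs, ys, epochs=100, learning_rate=0.1):
    """Gradient descent on one weight w with loss = mean((w*x - y)^2).
    `grad` plays the role of param.grad in the PyTorch loop."""
    w = 0.0
    grad = 0.0
    loss_arr = []
    n = len(xs)
    for _ in range(epochs):
        loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        loss_arr.append(loss)
        # "loss.backward()": accumulate into grad, as PyTorch does with .grad
        grad += sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
        # "param -= learning_rate * param.grad"
        w -= learning_rate * grad
        # "fn.zero_grad()": without this reset, old gradients would pile up
        grad = 0.0
    return w, loss_arr

w, losses = fit_scalar([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(w)  # close to 2.0
```

Dropping the `grad = 0.0` line makes each step use the sum of all past gradients, which behaves like an ever-growing effective step size and can keep the weight oscillating around its target instead of settling.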