Skip to content

Instantly share code, notes, and snippets.

@tongcezhou
Forked from dvgodoy/torch101_validation.py
Created April 11, 2020 12:10
Show Gist options
  • Save tongcezhou/ced49979aa3d3cb0941a83ed49b38a8b to your computer and use it in GitHub Desktop.
# Train for `n_epochs`, recording the value returned by `train_step` for every
# training mini-batch in `losses` and the loss of every validation mini-batch
# in `val_losses`, then print the learned parameters.
# Relies on names defined earlier in the script: make_train_step, model,
# loss_fn, optimizer, n_epochs, train_loader, val_loader, device, torch.
losses = []
val_losses = []
train_step = make_train_step(model, loss_fn, optimizer)

for epoch in range(n_epochs):
    # Explicitly return to training mode each epoch: the validation pass below
    # switches the model to eval mode, which changes the behaviour of layers
    # such as dropout and batch norm. Redundant (but harmless) if train_step
    # already does this itself.
    model.train()
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        loss = train_step(x_batch, y_batch)
        losses.append(loss)

    # Validation pass: eval mode is set once per pass (not once per batch),
    # and no_grad() disables autograd tracking since no backward pass happens.
    model.eval()
    with torch.no_grad():
        for x_val, y_val in val_loader:
            x_val = x_val.to(device)
            y_val = y_val.to(device)

            yhat = model(x_val)
            # PyTorch loss modules take (input, target). The original call had
            # the arguments swapped, which only happens to be harmless for
            # symmetric losses such as MSELoss.
            val_loss = loss_fn(yhat, y_val)
            val_losses.append(val_loss.item())

print(model.state_dict())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment