Skip to content

Instantly share code, notes, and snippets.

@amankharwal
Created September 21, 2020 01:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amankharwal/c2f973456232e1a2f2ecc99728fd75f9 to your computer and use it in GitHub Desktop.
Save amankharwal/c2f973456232e1a2f2ecc99728fd75f9 to your computer and use it in GitHub Desktop.
# Layer dimensions for the linear model: one input per entry in `input_cols`
# and one output per entry in `output_cols`.
input_size, output_size = len(input_cols), len(output_cols)
class CarsModel(nn.Module):
    """Single linear layer trained and validated with L1 (mean absolute error) loss.

    Args:
        in_features: Number of input features. When ``None`` (the default),
            the module-level ``input_size`` is used.
        out_features: Number of predicted values. When ``None`` (the
            default), the module-level ``output_size`` is used.
    """

    def __init__(self, in_features=None, out_features=None):
        super().__init__()
        # Resolve the fallbacks at call time so a bare ``CarsModel()`` keeps
        # using the module-level sizes, while explicit sizes make the class
        # usable without those globals.
        if in_features is None:
            in_features = input_size
        if out_features is None:
            out_features = output_size
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, xb):
        """Return the linear layer's output for the input batch ``xb``."""
        return self.linear(xb)

    def training_step(self, batch):
        """Return the L1 loss for one ``(inputs, targets)`` batch.

        The returned tensor is not detached, so it can be back-propagated.
        """
        inputs, targets = batch
        # Generate predictions
        out = self(inputs)
        # Calculate loss
        return F.l1_loss(out, targets)

    def validation_step(self, batch):
        """Return ``{'val_loss': loss}`` for one ``(inputs, targets)`` batch.

        The loss tensor is detached from the autograd graph.
        """
        inputs, targets = batch
        # Generate predictions
        out = self(inputs)
        # Calculate loss
        loss = F.l1_loss(out, targets)
        return {'val_loss': loss.detach()}

    def validation_epoch_end(self, outputs):
        """Combine per-batch validation results into one epoch-level result.

        Args:
            outputs: Sequence of dicts with a ``'val_loss'`` tensor, as
                produced by ``validation_step``.

        Returns:
            ``{'val_loss': float}`` holding the unweighted mean of the
            per-batch losses.
        """
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()  # Combine losses
        return {'val_loss': epoch_loss.item()}

    def epoch_end(self, epoch, result, num_epochs):
        """Print the validation loss every 20th epoch and on the final epoch.

        Args:
            epoch: Zero-based epoch index; it is printed one-based.
            result: Dict with a numeric ``'val_loss'`` entry.
            num_epochs: Total number of epochs in the run.
        """
        if (epoch + 1) % 20 == 0 or epoch == num_epochs - 1:
            print("Epoch [{}], val_loss: {:.4f}".format(epoch + 1, result['val_loss']))
# Instantiate the model using the module-level input/output sizes.
model = CarsModel()
# Materialise the model's parameter tensors as a list. The value is not
# stored; presumably this is a notebook cell whose trailing expression is
# meant to be displayed.
[*model.parameters()]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment