Skip to content

Instantly share code, notes, and snippets.

Created April 17, 2019 17:17
Show Gist options
  • Save averdones/ff8e2c04962f585168278d73e4b4a48a to your computer and use it in GitHub Desktop.
Save averdones/ff8e2c04962f585168278d73e4b4a48a to your computer and use it in GitHub Desktop.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
# --------------- Dataset ---------------
class StudentsPerformanceDataset(Dataset):
    """Students Performance dataset.

    Reads the CSV once, one-hot encodes the categorical columns and serves
    ``[features, target]`` pairs, where the target is the ``"math score"``
    column and the features are every other (encoded) column.
    """

    def __init__(self, csv_file):
        """Initializes instance of class StudentsPerformanceDataset.

        Args:
            csv_file (str): Path to the csv file with the students data.
        """
        df = pd.read_csv(csv_file)

        # Grouping variable names
        self.categorical = ["gender", "race/ethnicity", "parental level of education", "lunch",
                            "test preparation course"]
        self.target = "math score"

        # One-hot encoding of categorical variables. `columns` is passed
        # explicitly so exactly these columns are encoded and the `prefix`
        # list lines up with them one-to-one.
        self.students_frame = pd.get_dummies(df, columns=self.categorical, prefix=self.categorical)

        # Save target and predictors. Newer pandas emits boolean dummy
        # columns; mixed with the integer score columns each row would become
        # an object-dtype array that DataLoader cannot collate, so the
        # predictors are cast to one float dtype.
        self.X = self.students_frame.drop(self.target, axis=1).astype(np.float32)
        self.y = self.students_frame[self.target]

    def __len__(self):
        """Return the number of rows (students) in the dataset."""
        return len(self.students_frame)

    def __getitem__(self, idx):
        """Return ``[features, target]`` for the row(s) at position ``idx``."""
        # Convert idx from tensor to list due to pandas bug (that arises when using pytorch's random_split)
        if isinstance(idx, torch.Tensor):
            idx = idx.tolist()
        # Positional lookup on both X and y so they always refer to the same rows.
        return [self.X.iloc[idx].values, self.y.iloc[idx]]
# --------------- Model ---------------
class Net(nn.Module):
    """Fully-connected regressor: Linear -> ReLU -> Linear.

    Args:
        D_in (int): Number of input features.
        H (int): Width of the hidden layer.
        D_out (int): Number of outputs per sample.
    """

    def __init__(self, D_in, H=15, D_out=1):
        # nn.Module.__init__ must run before sub-modules are assigned as
        # attributes; otherwise PyTorch raises AttributeError.
        super().__init__()
        self.fc1 = nn.Linear(D_in, H)
        self.fc2 = nn.Linear(H, D_out)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map inputs whose last dimension is ``D_in`` to predictions.

        For a ``(batch, D_in)`` input the result is ``(batch,)`` when
        ``D_out == 1`` and ``(batch, D_out)`` otherwise.
        """
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        # Drop only the trailing size-1 output dimension; a bare squeeze()
        # would also collapse a batch of one into a 0-d tensor that no longer
        # matches the (1,) label shape.
        return x.squeeze(-1)
# --------------- Training ---------------
def _evaluate_mse(net, loader, device):
    """Return the mean squared error of ``net`` over every sample in ``loader``.

    Squared errors are summed and divided by the number of samples, so a
    shorter final batch is weighted correctly.
    """
    sum_criterion = nn.MSELoss(reduction="sum")
    net.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device).float()
            labels = labels.to(device).float()
            total += sum_criterion(net(inputs), labels).item()
            count += labels.numel()
    return total / count


def _plot_losses(loss_per_iter, loss_per_epoch, iters_per_epoch):
    """Plot the training loss per mini-batch and per epoch on an epoch axis."""
    # Mini-batch k (1-based) ends at epoch k / iters_per_epoch, so each
    # per-epoch average lands on the last mini-batch of its epoch.
    iter_x = np.arange(1, len(loss_per_iter) + 1) / iters_per_epoch
    epoch_x = np.arange(1, len(loss_per_epoch) + 1)
    plt.plot(iter_x, loss_per_iter, "-", alpha=0.5, label="Loss per mini-batch")
    plt.plot(epoch_x, loss_per_epoch, ".-", label="Loss per epoch")
    plt.xlabel("Number of epochs")
    plt.ylabel("Mean squared error")
    plt.legend()
    plt.show()


def train(csv_file, n_epochs=100, batch_size=200):
    """Trains the model.

    Splits the dataset 80/20 at random into training and test subsets, fits
    ``Net`` with Adam on the mean squared error, prints the root mean squared
    error on both subsets and plots the training loss curves.

    Args:
        csv_file (str): Absolute path of the dataset used for training.
        n_epochs (int): Number of epochs to train (must be >= 1).
        batch_size (int): Mini-batch size for the training and test loaders.

    Raises:
        ValueError: If ``n_epochs`` < 1 or the dataset is too small to split.
    """
    if n_epochs < 1:
        raise ValueError(f"n_epochs must be at least 1, got {n_epochs}")

    # Load dataset
    dataset = StudentsPerformanceDataset(csv_file)

    # Split into training and test
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    if train_size == 0 or test_size == 0:
        raise ValueError(f"dataset with {len(dataset)} rows is too small for an 80/20 split")
    trainset, testset = random_split(dataset, [train_size, test_size])

    # Dataloaders
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Use gpu if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Define the model; the input width follows the encoded feature matrix
    # instead of being hard-coded.
    D_in, H = dataset.X.shape[1], 15
    net = Net(D_in, H).to(device)

    # Loss function
    criterion = nn.MSELoss()

    # Optimizer
    optimizer = optim.Adam(net.parameters(), weight_decay=0.0001)

    # Train the net
    loss_per_iter = []   # one entry per mini-batch
    loss_per_epoch = []  # mean mini-batch loss of each epoch
    net.train()
    for epoch in range(n_epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs = inputs.to(device).float()
            labels = labels.to(device).float()

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Save loss to plot
            running_loss += loss.item()
            loss_per_iter.append(loss.item())

        loss_per_epoch.append(running_loss / len(trainloader))

    # Comparing training to test, both measured on the final weights over
    # every sample of each subset.
    train_mse = _evaluate_mse(net, trainloader, device)
    test_mse = _evaluate_mse(net, testloader, device)
    print("Root mean squared error")
    print("Training:", np.sqrt(train_mse))
    print("Test", np.sqrt(test_mse))

    # Plot training loss curve
    _plot_losses(loss_per_iter, loss_per_epoch, len(trainloader))
if __name__ == "__main__":
    import os
    import argparse

    # By default, read csv file in the same directory as this script.
    # __file__ is used instead of sys.path[0], which is not the script's
    # directory when it is not launched directly (e.g. `python -c`).
    csv_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "StudentsPerformance.csv")

    # Parsing arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", "-f", nargs="?", const=csv_file, default=csv_file,
                        help="Dataset file used for training")
    # With nargs="?", a bare "-e" uses `const`; without it the value would be
    # None and train() would fail on range(None).
    parser.add_argument("--epochs", "-e", type=int, nargs="?", const=100, default=100,
                        help="Number of epochs to train")
    args = parser.parse_args()

    # Call the main function of the script
    train(args.file, args.epochs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment