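"""
Train a two-layer GCN with PyTorch Geometric to predict per-hotspot reward
scales on city graphs, then report the hotspots and cities with the largest
prediction error on a held-out test set.
"""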
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
# note: in PyG >= 2.0 this DataLoader lives in torch_geometric.loader instead
from torch_geometric.data import DataLoader
import pickle
import random
import numpy as np
import pandas as pd

print('Loading dataset...')
with open('city_graph_datasets/city_data_2021_08_31_09_51_00.pkl', 'rb') as f:
    data_list = pickle.load(f)

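# Shuffle once, then split into train/val/test loaders by the given ratios.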
def create_dataloaders(data_list, train_ratio, val_ratio, batch_size):
    random.shuffle(data_list)
    n = len(data_list)
    train_end = int(train_ratio * n)
    val_end = int((train_ratio + val_ratio) * n)
    data_train = data_list[:train_end]
    data_val = data_list[train_end:val_end]
    data_test = data_list[val_end:]
    train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)  # reshuffle each epoch
    val_loader = DataLoader(data_val, batch_size=batch_size)
    test_loader = DataLoader(data_test, batch_size=1)  # one city graph per batch
    return train_loader, val_loader, test_loader

batch_size = 8
print('Splitting dataset and creating DataLoaders...')
train_loader, val_loader, test_loader = create_dataloaders(data_list, 0.7, 0.2, batch_size)
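
# Two GCN layers with dropout, followed by a linear head that emits one
# predicted reward scale per node (hotspot).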
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        num_features = train_loader.dataset[0].num_features
        self.conv1 = GCNConv(num_features, 16, normalize=False)
        self.conv2 = GCNConv(16, 16, normalize=False)
        self.lin2 = torch.nn.Linear(16, 1)

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        x = F.relu(x)
        x = self.lin2(x)
        return x

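# Train on GPU when available; Adam's weight_decay applies L2 regularization.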
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-4)
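
# Mean train/val loss per epoch, kept so learning curves can be inspected later.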
train_loss_epoch = []
val_loss_epoch = []
for epoch in range(200):
    train_losses, val_losses = [], []
    model.train()  # re-enable dropout after the previous epoch's eval()
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        train_loss = F.mse_loss(out, batch.y.reshape((-1, 1)))
        train_losses.append(train_loss.item())
        train_loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        for val_batch in val_loader:
            val_batch = val_batch.to(device)
            pred = model(val_batch)
            val_loss = F.mse_loss(pred, val_batch.y.reshape((-1, 1)))
            val_losses.append(val_loss.item())
    train_loss_epoch.append(np.mean(train_losses))
    val_loss_epoch.append(np.mean(val_losses))
    print(f"Epoch: {epoch}\t Train Loss: {np.mean(train_losses)}\t Validation Loss: {np.mean(val_losses)}")

print('Training complete. Evaluating on test data...')
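# Evaluate on the test set: per-city MSE and per-hotspot absolute error.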
test_losses_city = []
test_losses_hotspot = []
model.eval()
with torch.no_grad():
    for test_batch in test_loader:
        test_batch = test_batch.to(device)
        pred = model(test_batch)
        test_losses_city.append({
            'city_id': test_batch.city_id[0],
            'mse_loss': float(F.mse_loss(pred, test_batch.y.reshape((-1, 1))).cpu().numpy())
        })
        # batch_size=1, so hotspots[0] holds the hotspot list for this city
        for i in range(len(test_batch.hotspots[0])):
            scale_actual = float(test_batch.y[i].cpu().numpy())
            scale_pred = float(pred[i].cpu().numpy())
            test_losses_hotspot.append({
                'hotspot': test_batch.hotspots[0][i],
                'scale_actual': scale_actual,
                'scale_predicted': scale_pred,
                'difference': scale_actual - scale_pred,
                'absolute_difference': np.abs(scale_pred - scale_actual)
            })

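# Rank hotspots and cities by prediction error for inspection.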
df_cities = pd.DataFrame(test_losses_city).sort_values('mse_loss', ascending=False)
df_hotspots = pd.DataFrame(test_losses_hotspot).sort_values('absolute_difference', ascending=False)
print(f"Average absolute error between predicted and actual reward scales: "
      f"{np.round(np.mean(df_hotspots['absolute_difference']), 3)}%")
pd.set_option('display.max_colwidth', None)
print('Hotspots with highest error:\n', df_hotspots.head(10))
print('Cities with highest error:\n', df_cities.head(10))