import pyro
import pyro.distributions as dist
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch_geometric.data as data
import torch_geometric.transforms as T
from pyro.nn import PyroModule, PyroSample
from sklearn.metrics import roc_auc_score
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import negative_sampling
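
# Optional, and an assumption on our part rather than part of the original
# workflow: fix the RNG seeds (torch, numpy and Python `random`) so that the
# splits and negative sampling below are reproducible across runs.
pyro.set_rng_seed(0)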

device = 'cpu'

transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05,
                      num_test=0.1,
                      is_undirected=True,
                      add_negative_train_samples=False),
])

path = '../../data'
dataset = Planetoid(path, name='Cora', transform=transform)

# After applying the `RandomLinkSplit` transform, the dataset entry is no
# longer a single `Data` object but a tuple `(train_data, val_data,
# test_data)`, one element per split.
train_data, val_data, test_data = dataset[0]
# `train_data.edge_index` now holds only the edges of the training split.
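
# Illustrative only: each split keeps the full node set but scores a different
# set of edges via `edge_label_index`.
print(f'train edges: {train_data.edge_index.size(1)}, '
      f'val edges to score: {val_data.edge_label_index.size(1)}, '
      f'test edges to score: {test_data.edge_label_index.size(1)}')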

def get_labeled_edges(data):
    # Perform a fresh round of negative sampling on every call (i.e., once
    # per training epoch):
    neg_edge_index = negative_sampling(
        edge_index=data.edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.edge_label_index.size(1),
        method='sparse')
    edge_label_index = torch.cat(
        [data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat(
        [data.edge_label,
         data.edge_label.new_zeros(neg_edge_index.size(1))],
        dim=0,
    )
    return edge_label_index, edge_label


train_data.edge_label_index, train_data.edge_label = get_labeled_edges(
    train_data)
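
# Illustrative only: after appending the sampled negatives, the training
# labels should be roughly balanced between positives (1) and negatives (0).
print(f'labeled training edges: {train_data.edge_label.numel()} '
      f'({int(train_data.edge_label.sum())} positive)')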

# Wrap the training split in a `LightningDataset`; it is created here but not
# used below, since the model defines its own dataloaders.
pdl = data.LightningDataset(train_data)

class PyroEmbedding(nn.Embedding, PyroModule):
    """An embedding table whose weight matrix carries a standard normal prior."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        z_loc = torch.zeros(self.num_embeddings, self.embedding_dim)
        z_scale = torch.ones(self.num_embeddings, self.embedding_dim)
        z_dist = dist.Normal(z_loc, z_scale).to_event(2)
        # `PyroSample` turns `weight` into a latent variable that is sampled
        # from the prior (or replayed from a guide) whenever the module runs.
        self.weight = PyroSample(z_dist)
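
# A minimal, illustrative sketch (the sizes are arbitrary): running the module
# under a poutine trace shows that `weight` is recorded as a sample site drawn
# from its prior rather than stored as a fixed parameter.
_demo = PyroEmbedding(3, 2)
_trace = pyro.poutine.trace(_demo).get_trace(torch.arange(3))
print(type(_trace.nodes['weight']['fn']).__name__,
      _trace.nodes['weight']['value'].shape)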

def print_param_store():
    """Debugging helper: dump the current contents of the Pyro param store."""
    print(dict(pyro.get_param_store()))

class DummyDataset(data.Dataset):
    """A one-item placeholder dataset so that Lightning runs each loop once
    per epoch; the actual graph data lives on the model itself."""

    def __init__(self):
        pass

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return 1

class PyroLVB(pl.LightningModule, PyroModule):
    """Latent-variable Bernoulli link predictor trained with stochastic
    variational inference (SVI) inside a Lightning loop."""

    def __init__(self, train_data, dim=2, bias=0.2, optim_hparams={'lr': 0.1}):
        super().__init__()
        pyro.clear_param_store()
        self.bias = nn.Parameter(torch.tensor(bias))  # currently unused in `decode`
        self.z = PyroEmbedding(train_data.num_nodes, dim)
        self.optim_hparams = optim_hparams
        self.elbo = pyro.infer.Trace_ELBO()
        self.guide = pyro.infer.autoguide.AutoNormal(self)
        self.train_data = train_data
        # Run the guide once so its parameters get registered in the Pyro
        # param store, then hand their unconstrained views to the optimizer.
        with pyro.poutine.trace(param_only=True):
            self.guide(train_data)
        self.params = [
            v.unconstrained() for k, v in dict(pyro.get_param_store()).items()
        ]

    def configure_optimizers(self):
        # Use the configured learning rate instead of a hard-coded one.
        return torch.optim.Adam(self.params,
                                lr=self.optim_hparams['lr'],
                                betas=(0.9, 0.999))

    def encode(self, x, edge_index):
        # Not implemented: latent positions come straight from the embedding
        # table (`self.z`) rather than from an encoder network.
        raise NotImplementedError

    def similarity(self, z1, z2):
        return (z1 * z2).sum(dim=-1)

    def decode(self, edge_label_index):
        z = self.z(edge_label_index)  # shape [2, batch, dim]
        logits = self.similarity(z[0], z[1])
        return logits  # shape [batch]

    def forward(self, data):
        edge_label_index, edge_label = get_labeled_edges(data)
        logits = self.decode(edge_label_index)
        bern = dist.Bernoulli(logits=logits)
        with pyro.plate("edges", len(edge_label)):
            pyro.sample("edge_label", bern, obs=edge_label)
        return logits

    def training_step(self, batch, batch_idx=None):
        loss = self.elbo.differentiable_loss(self.forward, self.guide,
                                             self.train_data)
        self.log('loss', loss)
        self.eval_step(self.train_data, 'train')
        return {'loss': loss}

    def eval_step(self, data, mode='val'):
        edge_label_index, edge_label = get_labeled_edges(data)
        # Replay `decode` under a guide sample so evaluation uses the learned
        # posterior embeddings rather than fresh prior draws.
        guide_trace = pyro.poutine.trace(self.guide).get_trace(data)
        logits = pyro.poutine.replay(self.decode,
                                     trace=guide_trace)(edge_label_index)
        self.log(f'auc_{mode}',
                 roc_auc_score(edge_label.cpu().numpy(),
                               logits.detach().cpu().numpy()))

    def validation_step(self, batch, batch_idx=None):
        # `val_data` and `test_data` are the module-level splits created above.
        self.eval_step(val_data, 'val')

    def test_step(self, batch, batch_idx=None):
        self.eval_step(test_data, 'test')

    def train_dataloader(self):
        return data.DataLoader(DummyDataset(), batch_size=1, shuffle=False)

    def val_dataloader(self):
        return data.DataLoader(DummyDataset(), batch_size=1, shuffle=False)

    def test_dataloader(self):
        return data.DataLoader(DummyDataset(), batch_size=1, shuffle=False)

# Instantiate the model and fit it for a few epochs, logging to Weights & Biases.
plvb = PyroLVB(train_data, 2).to(device)
trainer = pl.Trainer(logger=pl.loggers.WandbLogger(mode='online'),
                     max_epochs=10,
                     fast_dev_run=False)
trainer.fit(plvb)
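
# After fitting, the learned approximate posterior lives in the guide. A short
# sketch (assuming the latent site is named 'z.weight', which follows from the
# attribute path `PyroLVB.z.weight`): `AutoNormal.median` returns the posterior
# median for every latent site.
posterior = plvb.guide.median()
z_hat = posterior['z.weight']  # shape [num_nodes, dim]
print(z_hat.shape)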