Skip to content

Instantly share code, notes, and snippets.

@rapharomero
Created March 9, 2023 16:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rapharomero/6b0dcfe03d20e0e0fd6cf5fea67fa8f7 to your computer and use it in GitHub Desktop.
import pyro
import pyro.distributions as dist
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch_geometric.data as data
import torch_geometric.transforms as T
from pyro.nn import PyroModule, PyroSample
from sklearn.metrics import roc_auc_score
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import negative_sampling
# Run everything on CPU; Cora is small enough that no GPU is needed.
device = 'cpu'
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    # Split edges into train/val/test. Training negatives are re-sampled every
    # epoch (see get_labeled_edges), hence add_negative_train_samples=False.
    T.RandomLinkSplit(num_val=0.05,
                      num_test=0.1,
                      is_undirected=True,
                      add_negative_train_samples=False),
])
path = '../../data'
dataset = Planetoid(path, name='Cora', transform=transform)
# After applying the `RandomLinkSplit` transform, the data is transformed from
# a data object to a list of tuples (train_data, val_data, test_data), with
# each element representing the corresponding split.
train_data, val_data, test_data = dataset[0]
# NOTE(review): this bare attribute access is a no-op (notebook leftover).
train_data.edge_index
def get_labeled_edges(data):
    """Build supervision edges by augmenting positives with fresh negatives.

    A new round of negative sampling is drawn on every call (i.e. every
    training epoch), with as many negatives as there are positive
    supervision edges.

    Returns:
        (edge_label_index, edge_label): the positive edges followed by the
        sampled negatives, and the matching 1/0 labels.
    """
    num_pos = data.edge_label_index.size(1)
    neg_edge_index = negative_sampling(
        edge_index=data.edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=num_pos,
        method='sparse',
    )
    edge_label_index = torch.cat(
        (data.edge_label_index, neg_edge_index), dim=-1)
    neg_labels = data.edge_label.new_zeros(neg_edge_index.size(1))
    edge_label = torch.cat((data.edge_label, neg_labels), dim=0)
    return edge_label_index, edge_label
# Attach positive+negative supervision edges to the training split.
train_data.edge_label_index, train_data.edge_label = get_labeled_edges(
    train_data)
# NOTE(review): `pdl` is never used below — Trainer.fit is called without a
# datamodule, and the model supplies its own dummy dataloaders.
pdl = data.LightningDataset(train_data)
class PyroEmbedding(nn.Embedding, PyroModule):
    """Embedding layer whose weight matrix is a Pyro latent variable.

    The deterministic `nn.Embedding` weight is replaced by a `PyroSample`
    drawn from a standard Normal prior over the whole
    (num_embeddings, embedding_dim) matrix.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        shape = (self.num_embeddings, self.embedding_dim)
        prior = dist.Normal(torch.zeros(shape), torch.ones(shape)).to_event(2)
        self.weight = PyroSample(prior)
def print_param_store():
    """Print the current contents of Pyro's global parameter store."""
    store = dict(pyro.get_param_store())
    print(store)
class DummyDataset(data.Dataset):
    """One-item placeholder dataset.

    Lightning requires dataloaders, but this model trains on the full graph
    stored on the module itself, so the loader only needs to trigger a single
    step per epoch.
    """

    def __init__(self):
        # NOTE(review): the base-class __init__ is deliberately skipped here;
        # __len__/__getitem__ are overridden directly below.
        pass

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return 1
class PyroLVB(pl.LightningModule, PyroModule):
    """Bayesian latent-variable link predictor trained by SVI on an ELBO.

    A Normal prior is placed on node embeddings (via ``PyroEmbedding``); an
    edge (i, j) is scored by the dot product of its endpoint embeddings and
    observed through a Bernoulli likelihood. Training/evaluation operate on
    the full graph splits held on the module, so the Lightning dataloaders
    are one-item dummies that merely trigger one step per epoch.

    Args:
        train_data: training split produced by ``RandomLinkSplit``.
        dim: embedding dimensionality.
        bias: initial value of the (currently unused) bias parameter.
        optim_hparams: optimizer hyper-parameters; defaults to ``{'lr': 0.1}``.
    """

    def __init__(self, train_data, dim=2, bias=0.2, optim_hparams=None):
        super().__init__()
        pyro.clear_param_store()
        self.bias = nn.Parameter(torch.tensor(bias))
        self.z = PyroEmbedding(train_data.num_nodes, dim)
        # Avoid a mutable default argument: build the default dict per call.
        self.optim_hparams = {'lr': 0.1} if optim_hparams is None else optim_hparams
        self.elbo = pyro.infer.Trace_ELBO()
        self.guide = pyro.infer.autoguide.AutoNormal(self)
        self.train_data = train_data
        # Run the guide once so AutoNormal materializes its parameters in the
        # Pyro param store; their unconstrained tensors are handed to Adam.
        with pyro.poutine.trace(param_only=True):
            self.guide(train_data)
        self.params = [
            v.unconstrained() for v in dict(pyro.get_param_store()).values()
        ]

    def configure_optimizers(self):
        # Use the configured learning rate instead of a hard-coded 0.1.
        return torch.optim.Adam(self.params,
                                lr=self.optim_hparams.get('lr', 0.1),
                                betas=(0.90, 0.999))

    def encode(self, x, edge_index):
        # Unused: embeddings are sampled directly through ``self.z``.
        # Fixed: original had ``raise NotImplemented()`` which raises a
        # TypeError (NotImplemented is not callable), not the intended error.
        raise NotImplementedError

    def similarity(self, z1, z2):
        """Dot-product score between two batches of embeddings."""
        return (z1 * z2).sum(dim=-1)

    def decode(self, edge_label_index):
        """Return edge logits for the given (2, batch) endpoint indices."""
        z = self.z(edge_label_index)  # shape [2, batch, dim]
        logits = self.similarity(z[0], z[1])
        return logits  # shape [batch]

    def forward(self, data):
        """Pyro model: observe edge labels under a Bernoulli likelihood."""
        edge_label_index, edge_label = get_labeled_edges(data)
        logits = self.decode(edge_label_index)
        bern = dist.Bernoulli(logits=logits)
        with pyro.plate("edges", len(edge_label)):
            pyro.sample("edge_label", bern, obs=edge_label)
        return logits

    def training_step(self, batch, batch_idx=None):
        # The dataloader batch is a dummy; the full training graph lives on
        # the module itself.
        loss = self.elbo.differentiable_loss(self.forward, self.guide,
                                             self.train_data)
        self.log('loss', loss)
        # Fixed: use self.train_data rather than the module-level global.
        self.eval_step(self.train_data, 'train')
        return {'loss': loss}

    def eval_step(self, data, mode='val'):
        """Log link-prediction ROC-AUC for the given split."""
        edge_label_index, edge_label = get_labeled_edges(data)
        logits = self.decode(edge_label_index)
        # Detach before handing tensors to sklearn: roc_auc_score converts to
        # numpy, which fails on tensors that require grad.
        auc = roc_auc_score(edge_label.detach().cpu().numpy(),
                            logits.detach().cpu().numpy())
        self.log(f'auc_{mode}', auc)

    def validation_step(self, batch, batch_idx=None):
        # NOTE(review): relies on the module-level ``val_data`` split.
        self.eval_step(val_data, 'val')

    def test_step(self, batch, batch_idx=None):
        # NOTE(review): relies on the module-level ``test_data`` split.
        self.eval_step(test_data, 'test')

    def train_dataloader(self):
        return data.DataLoader(DummyDataset(), batch_size=1, shuffle=False)

    def val_dataloader(self):
        return data.DataLoader(DummyDataset(), batch_size=1, shuffle=False)

    def test_dataloader(self):
        return data.DataLoader(DummyDataset(), batch_size=1, shuffle=False)
# Instantiate the model on the chosen device and fit it with Lightning,
# logging metrics to Weights & Biases.
plvb = PyroLVB(train_data, 2).to(device)
trainer = pl.Trainer(logger=pl.loggers.WandbLogger(mode='online'),
                     max_epochs=10,
                     fast_dev_run=False)
trainer.fit(plvb)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment