Skip to content

Instantly share code, notes, and snippets.

@tchaton
Last active December 7, 2020 14:04
Show Gist options
  • Save tchaton/44d2373ab943934e21db535ed4d29434 to your computer and use it in GitHub Desktop.
Save tchaton/44d2373ab943934e21db535ed4d29434 to your computer and use it in GitHub Desktop.
...
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics import Accuracy
class DNAConvNet(LightningModule):
    """Node classifier built from stacked DNA graph convolutions.

    DNAConv (Dynamic Neighborhood Aggregation) routes information between
    nodes with a multi-head attention mechanism, similar to Transformers:
    https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/conv/dna_conv.py

    Each layer attends over the embeddings of *all* previous layers
    (accumulated in ``x_all``), and the final layer's embedding feeds a
    linear classification head producing per-node log-probabilities.
    """

    def __init__(self,
                 num_layers: int = 2,
                 hidden_channels: int = 128,
                 heads: int = 8,
                 groups: int = 16,
                 dropout: float = 0.8,
                 cached: bool = False,
                 num_features: int = None,
                 num_classes: int = None,
                 ):
        """Build the model.

        Args:
            num_layers: number of stacked DNAConv layers.
            hidden_channels: size of the node embedding space.
            heads: attention heads per DNAConv.
            groups: groups per DNAConv.
            dropout: dropout probability inside each DNAConv.
            cached: whether DNAConv may cache normalization (see NOTE below).
            num_features: input feature dimension per node.
            num_classes: number of output classes.
        """
        super().__init__()
        # Stores every constructor argument on self.hparams (and in checkpoints).
        self.save_hyperparameters()
        hparams = self.hparams

        # One Accuracy metric per stage so internal state never mixes
        # across train/val/test.
        self.train_acc = Accuracy(hparams["num_classes"])
        self.val_acc = Accuracy(hparams["num_classes"])
        self.test_acc = Accuracy(hparams["num_classes"])

        # Input projection into the hidden embedding space.
        self.hidden_channels = hparams["hidden_channels"]
        self.lin1 = nn.Linear(hparams["num_features"],
                              hparams["hidden_channels"])

        # ModuleList registers every convolution's parameters with the module.
        self.convs = nn.ModuleList(
            DNAConv(
                hparams["hidden_channels"],
                hparams["heads"],
                hparams["groups"],
                dropout=hparams["dropout"],
                # NOTE(review): hard-coded False ignores the `cached`
                # hyperparameter -- confirm whether hparams["cached"] was
                # intended here (caching is only valid transductively).
                cached=False,
            )
            for _ in range(hparams["num_layers"])
        )

        # Classification head: hidden embedding -> class logits.
        self.lin2 = nn.Linear(hparams["hidden_channels"],
                              hparams["num_classes"],
                              bias=False)

    def forward(self, batch):
        """Return per-node class log-probabilities for `batch`.

        `batch` must expose `x` (node features) and `edge_indexes`, one
        edge_index per convolution layer.
        """
        x = F.relu(self.lin1(batch.x))
        x = F.dropout(x, p=0.5, training=self.training)
        # x_all accumulates one embedding per layer: (nodes, layers_so_far, channels).
        x_all = x.view(-1, 1, self.hidden_channels)
        for idx, conv in enumerate(self.convs):
            # Each DNAConv attends over all previously computed embeddings.
            x = F.relu(conv(x_all, batch.edge_indexes[idx]))
            x = x.view(-1, 1, self.hidden_channels)
            x_all = torch.cat([x_all, x], dim=1)
        # Classify from the last layer's embedding only.
        x = x_all[:, -1]
        x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(self.lin2(x), -1)

    def common_step(self, batch, batch_nb, stage):
        """Shared train/val/test step: forward, loss, metric logging.

        Returns the loss for the "train" stage, None otherwise.
        """
        # self.gather_data is a DataModule function.
        batch, targets = self.gather_data(batch, batch_nb)
        # BUG FIX: forward() takes the whole batch object (it reads batch.x
        # and batch.edge_indexes); the original passed (batch.x, batch.edge_index),
        # which does not match forward's signature.
        logits = self(batch)
        logged_acc = f"{stage}_acc"
        acc_metric = getattr(self, logged_acc)
        loss = F.nll_loss(logits, targets)
        # BUG FIX: the original logged `val_loss`, an undefined name
        # (NameError at runtime); log the loss actually computed above.
        self.log(f"{stage}_loss", loss, on_step=False,
                 on_epoch=True, prog_bar=True)
        self.log(logged_acc, acc_metric(logits, targets), on_step=False,
                 on_epoch=True, prog_bar=True)
        if stage == "train":
            return loss

    def training_step(self, batch, batch_nb):
        return self.common_step(batch, batch_nb, "train")

    def validation_step(self, batch, batch_nb):
        return self.common_step(batch, batch_nb, "val")

    def test_step(self, batch, batch_nb):
        return self.common_step(batch, batch_nb, "test")

    def configure_optimizers(self):
        # Plain Adam over every parameter; lr is fixed, not a hyperparameter.
        return Adam(self.parameters(), lr=1e-3)

    @staticmethod
    def add_argparse_args(parser):
        """Register this model's hyperparameters on an argparse parser."""
        parser.add_argument("--num_layers", type=int, default=2)
        parser.add_argument("--hidden_channels", type=int, default=128)
        parser.add_argument("--heads", type=int, default=8)
        parser.add_argument("--groups", type=int, default=16)
        parser.add_argument("--dropout", type=float, default=0.8)
        parser.add_argument("--cached", type=int, default=0)
        # NOTE(review): no type= given, so any command-line value arrives as a
        # non-empty string (truthy); `--jit 0` would still evaluate True.
        parser.add_argument("--jit", default=True)
        return parser
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment