Last active: December 7, 2020, 14:04
-
-
Save tchaton/44d2373ab943934e21db535ed4d29434 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
... | |
from pytorch_lightning import LightningModule | |
from pytorch_lightning.metrics import Accuracy | |
class DNAConvNet(LightningModule):
    """Node-classification model built from stacked DNA graph convolutions.

    Pipeline: a linear embedding of node features, ``num_layers`` DNAConv
    layers whose outputs are concatenated so each layer attends over all
    previous embeddings, then a bias-free linear classifier producing
    per-node log-probabilities.

    Args:
        num_layers: number of DNAConv layers to stack.
        hidden_channels: width of the node embeddings.
        heads: number of attention heads per DNAConv.
        groups: number of groups per DNAConv.
        dropout: dropout probability used inside each DNAConv.
        cached: whether DNAConv should cache normalized adjacency.
        num_features: input node-feature dimensionality (required).
        num_classes: number of output classes (required).
    """

    def __init__(self,
                 num_layers: int = 2,
                 hidden_channels: int = 128,
                 heads: int = 8,
                 groups: int = 16,
                 dropout: float = 0.8,
                 cached: bool = False,
                 num_features: int = None,
                 num_classes: int = None,
                 ):
        super().__init__()
        # Record every constructor argument on self.hparams (and in checkpoints).
        self.save_hyperparameters()
        hparams = self.hparams

        # Instantiate metrics — one Accuracy object per stage so internal
        # metric state never mixes across train/val/test.
        # NOTE(review): positional arg to Accuracy assumed to be num_classes
        # for this pytorch_lightning.metrics version — confirm against the
        # installed API.
        self.train_acc = Accuracy(hparams["num_classes"])
        self.val_acc = Accuracy(hparams["num_classes"])
        self.test_acc = Accuracy(hparams["num_classes"])

        # Define DNA graph convolution model.
        self.hidden_channels = hparams["hidden_channels"]
        self.lin1 = nn.Linear(hparams["num_features"],
                              hparams["hidden_channels"])

        # ModuleList holds all convolutions so parameters are registered.
        self.convs = nn.ModuleList()
        for _ in range(hparams["num_layers"]):
            # DNAConv routes information via multi-head attention,
            # similar to Transformers:
            # https://github.com/rusty1s/pytorch_geometric/
            # blob/master/torch_geometric/nn/conv/dna_conv.py#L172
            self.convs.append(
                DNAConv(
                    hparams["hidden_channels"],
                    hparams["heads"],
                    hparams["groups"],
                    dropout=hparams["dropout"],
                    # Bug fix: was hard-coded to False, silently ignoring
                    # the `cached` constructor argument.
                    cached=hparams["cached"],
                )
            )

        # classification MLP
        self.lin2 = nn.Linear(hparams["hidden_channels"],
                              hparams["num_classes"],
                              bias=False)

    def forward(self, batch):
        """Return per-node class log-probabilities for ``batch``.

        Expects ``batch.x`` (node features) and ``batch.edge_indexes``,
        one edge_index per convolution layer.
        """
        x = F.relu(self.lin1(batch.x))
        x = F.dropout(x, p=0.5, training=self.training)
        # x_all accumulates one embedding per layer along dim=1 so each
        # DNAConv can attend over every previous layer's output.
        x_all = x.view(-1, 1, self.hidden_channels)
        for idx, conv in enumerate(self.convs):
            # Convolve over the concatenated embeddings using this
            # layer's edge_index.
            x = F.relu(conv(x_all, batch.edge_indexes[idx]))
            x = x.view(-1, 1, self.hidden_channels)
            # Append this layer's embedding to the running stack.
            x_all = torch.cat([x_all, x], dim=1)
        # Keep only the latest layer's embedding for classification.
        x = x_all[:, -1]
        x = F.dropout(x, p=0.5, training=self.training)
        # Log-probabilities per node (paired with F.nll_loss below).
        return F.log_softmax(self.lin2(x), -1)

    def common_step(self, batch, batch_nb, stage):
        """Shared train/val/test step: forward, loss, metric logging.

        Returns the loss only for the "train" stage (Lightning needs it
        there for backprop); returns None otherwise.
        """
        # NOTE(review): gather_data is not defined in this file — assumed
        # to be supplied by the DataModule / a mixin. Confirm.
        batch, targets = self.gather_data(batch, batch_nb)
        # Bug fix: forward() takes the whole batch object; the original
        # called self(batch.x, batch.edge_index), a TypeError at runtime.
        logits = self(batch)
        logged_acc = f"{stage}_acc"
        acc_metric = getattr(self, logged_acc)
        loss = F.nll_loss(logits, targets)
        # Bug fix: original logged the undefined name `val_loss`
        # (NameError); the computed `loss` is what should be logged.
        self.log(f"{stage}_loss", loss, on_step=False,
                 on_epoch=True, prog_bar=True)
        self.log(logged_acc, acc_metric(logits, targets), on_step=False,
                 on_epoch=True, prog_bar=True)
        if stage == "train":
            return loss

    def training_step(self, batch, batch_nb):
        return self.common_step(batch, batch_nb, "train")

    def validation_step(self, batch, batch_nb):
        return self.common_step(batch, batch_nb, "val")

    def test_step(self, batch, batch_nb):
        return self.common_step(batch, batch_nb, "test")

    def configure_optimizers(self):
        # Plain Adam over all parameters with a fixed learning rate.
        return Adam(self.parameters(), lr=1e-3)

    @staticmethod
    def add_argparse_args(parser):
        """Register this model's hyperparameters on an argparse parser."""
        parser.add_argument("--num_layers", type=int, default=2)
        parser.add_argument("--hidden_channels", type=int, default=128)
        parser.add_argument("--heads", type=int, default=8)
        parser.add_argument("--groups", type=int, default=16)
        parser.add_argument("--dropout", type=float, default=0.8)
        # int (0/1) rather than bool: argparse's type=bool is a known trap.
        parser.add_argument("--cached", type=int, default=0)
        parser.add_argument("--jit", default=True)
        return parser
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment