import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Module
from torch_geometric.nn.conv import DNAConv


class DNAConvNet(Module):
    def __init__(self,
                 num_layers: int = 2,
                 hidden_channels: int = 128,
                 heads: int = 8,
                 groups: int = 16,
                 dropout: float = 0.8,
                 cached: bool = False,
                 num_features: int = None,
                 num_classes: int = None,
                 ):
        super().__init__()
        # perform some checks on values ...
        self.hidden_channels = hidden_channels

        # Define the DNA graph convolution model
        self.lin1 = nn.Linear(num_features, hidden_channels)
        # Create a ModuleList to hold all convolutions
        self.convs = nn.ModuleList()
        # Iterate through the number of layers
        for _ in range(num_layers):
            # Create a DNA convolution - this graph convolution relies on
            # a multi-head attention mechanism to route information,
            # similar to Transformers.
            # https://github.com/rusty1s/pytorch_geometric/blob/
            # master/torch_geometric/nn/conv/dna_conv.py#L172
            self.convs.append(
                DNAConv(
                    hidden_channels,
                    heads,
                    groups,
                    dropout,
                    cached=cached,
                )
            )
        # classification MLP
        self.lin2 = nn.Linear(hidden_channels, num_classes, bias=False)

    def forward(self, batch):
        x = F.relu(self.lin1(batch.x))
        x = F.dropout(x, p=0.5, training=self.training)
        x_all = x.view(-1, 1, self.hidden_channels)

        # iterate over all convolutions
        for idx, conv in enumerate(self.convs):
            # perform the convolution on the previously concatenated
            # embeddings, using this layer's edge_index
            x = F.relu(conv(x_all, batch.edge_indexes[idx]))
            x = x.view(-1, 1, self.hidden_channels)
            # concatenate with the previously computed embeddings
            x_all = torch.cat([x_all, x], dim=1)

        # extract the latest layer embedding
        x = x_all[:, -1]
        x = F.dropout(x, p=0.5, training=self.training)
        # return log-probabilities per node
        return F.log_softmax(self.lin2(x), dim=-1)
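

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original snippet): the `batch` object
# and its `x` / `edge_indexes` attributes are assumed here purely for
# illustration, mirroring how `forward` accesses them - adapt this to however
# your data loader actually builds batches.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    num_nodes, num_features, num_classes, num_layers = 100, 32, 7, 2

    # random node features and one random edge_index per convolution layer
    batch = SimpleNamespace(
        x=torch.randn(num_nodes, num_features),
        edge_indexes=[
            torch.randint(0, num_nodes, (2, 500)) for _ in range(num_layers)
        ],
    )

    model = DNAConvNet(num_layers=num_layers,
                       hidden_channels=128,
                       num_features=num_features,
                       num_classes=num_classes)
    logits = model(batch)
    print(logits.shape)  # expected: torch.Size([100, 7])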