Skip to content

Instantly share code, notes, and snippets.

@tchaton
Last active December 7, 2020 14:05
Show Gist options
  • Save tchaton/268cb5cad187c78861f39b0bac229259 to your computer and use it in GitHub Desktop.
Save tchaton/268cb5cad187c78861f39b0bac229259 to your computer and use it in GitHub Desktop.
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module
from torch_geometric.nn.conv import DNAConv
class DNAConvNet(Module):
    """Node classifier built from a stack of ``DNAConv`` graph convolutions.

    Input node features are projected to ``hidden_channels`` by ``lin1``. Each
    ``DNAConv`` layer then receives the stack of every embedding computed so far
    (the input projection plus each previous layer's output). The last layer's
    embedding is mapped to per-class log-probabilities by ``lin2``.

    Args:
        num_layers: Number of ``DNAConv`` layers.
        hidden_channels: Width of the hidden node embeddings.
        heads: Number of attention heads, forwarded to each ``DNAConv``.
        groups: Number of groups, forwarded to each ``DNAConv``.
        dropout: Dropout rate forwarded to each ``DNAConv``. The feature dropout
            applied in :meth:`forward` is separate and fixed at 0.5.
        cached: Forwarded to each ``DNAConv`` as its ``cached`` flag.
        num_features: Size of the input node features (required).
        num_classes: Number of output classes (required).

    Raises:
        ValueError: If ``num_features`` or ``num_classes`` is not provided.
    """

    def __init__(
        self,
        num_layers: int = 2,
        hidden_channels: int = 128,
        heads: int = 8,
        groups: int = 16,
        dropout: float = 0.8,
        cached: bool = False,
        num_features: Optional[int] = None,
        num_classes: Optional[int] = None,
    ):
        super().__init__()
        # Fail fast with a clear message instead of an opaque error from nn.Linear.
        if num_features is None:
            raise ValueError("num_features must be provided")
        if num_classes is None:
            raise ValueError("num_classes must be provided")

        # forward() reshapes embeddings using this width, so it must live on the module.
        self.hidden_channels = hidden_channels

        # Input projection: num_features -> hidden_channels.
        self.lin1 = nn.Linear(num_features, hidden_channels)

        # One DNA convolution per layer. Per the original notes, this convolution
        # routes information with a multi-head attention mechanism, similar to
        # Transformers:
        # https://github.com/rusty1s/pytorch_geometric/blob/
        # master/torch_geometric/nn/conv/dna_conv.py#L172
        self.convs = nn.ModuleList(
            DNAConv(hidden_channels, heads, groups, dropout, cached=cached)
            for _ in range(num_layers)
        )

        # Classification head: hidden_channels -> num_classes.
        self.lin2 = nn.Linear(hidden_channels, num_classes, bias=False)

    def forward(self, batch):
        """Return per-node class log-probabilities.

        Args:
            batch: Object exposing ``x`` (node features, fed to ``lin1``) and
                ``edge_indexes``, which is indexed once per conv layer.
                NOTE(review): this assumes ``edge_indexes`` holds one
                ``edge_index`` per layer (e.g. from layer-wise neighbor
                sampling), with at least ``num_layers`` entries. Confirm this
                against the caller.

        Returns:
            Log-softmax scores over the last dimension, one row per row of ``batch.x``.
        """
        x = F.relu(self.lin1(batch.x))
        x = F.dropout(x, p=0.5, training=self.training)

        # Stack of all embeddings so far, shaped [num_nodes, layers_so_far, hidden_channels].
        x_all = x.view(-1, 1, self.hidden_channels)
        for idx, conv in enumerate(self.convs):
            # Each conv sees the whole stack and yields one new embedding per node.
            x = F.relu(conv(x_all, batch.edge_indexes[idx]))
            x = x.view(-1, 1, self.hidden_channels)
            # Append the new embedding along the "layer" dimension.
            x_all = torch.cat([x_all, x], dim=1)

        # The most recent embedding (or the input projection when there are no conv layers).
        x = x_all[:, -1]
        x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(self.lin2(x), -1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment