# EGNN in PyTorch (gist by @lucidrains, last active February 26, 2021)
import torch
from torch import nn, einsum
from einops import rearrange, repeat
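
# A sketch of an EGNN layer, following "E(n) Equivariant Graph Neural Networks"
# (Satorras, Hoogeboom & Welling, 2021): each message m_ij is built from the two
# endpoint features, optional edge features, and the pairwise distance, and
# coordinates are updated along relative direction vectors scaled by a learned weight.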
class EGNN(nn.Module):
    def __init__(
        self,
        dim,
        edge_dim,
        m_dim = 16
    ):
        super().__init__()

        # the edge MLP consumes both endpoint features, the edge features, and the pairwise distance
        input_dim = 2 * dim + edge_dim + 1

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim * 2),
            nn.ReLU(),
            nn.Linear(input_dim * 2, m_dim)
        )

        # maps each message m_ij to a scalar weight for the coordinate update
        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            nn.ReLU(),
            nn.Linear(m_dim * 4, 1)
        )

        # updates node features from the aggregated messages
        self.hidden_mlp = nn.Sequential(
            nn.Linear(dim + m_dim, dim * 2),
            nn.ReLU(),
            nn.Linear(dim * 2, dim),
        )

    def forward(self, feats, coors, edges = None):
        b, n, d = feats.shape

        # relative coordinate vectors and pairwise distances between all nodes
        rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d')
        rel_dist = rel_coors.norm(dim = -1, keepdim = True)

        # broadcast node features across all (i, j) pairs
        feats_i = repeat(feats, 'b i d -> b i n d', n = n)
        feats_j = repeat(feats, 'b j d -> b n j d', n = n)

        # edges are optional; only concatenate them when given
        edge_input = torch.cat((feats_i, feats_j, rel_dist), dim = -1)
        if edges is not None:
            edge_input = torch.cat((edge_input, edges), dim = -1)

        m_ij = self.edge_mlp(edge_input)

        # a scalar weight per pair, applied to the relative direction vectors;
        # the residual update keeps the layer translation equivariant
        coor_weights = self.coors_mlp(m_ij)
        coor_weights = rearrange(coor_weights, 'b i j () -> b i j')
        coors_out = coors + einsum('b i j, b i j c -> b i c', coor_weights, rel_coors)

        # aggregate messages over neighbors j, then update node features
        m_i = m_ij.sum(dim = -2)
        hidden_mlp_input = torch.cat((feats, m_i), dim = -1)
        hidden_out = self.hidden_mlp(hidden_mlp_input)

        return hidden_out, coors_out

# example usage: two stacked layers, the second consuming the updated feats and coors
layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges)

print(feats.shape, coors.shape)  # torch.Size([1, 16, 512]) torch.Size([1, 16, 3])
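
# A hypothetical sanity check, not part of the original gist: an EGNN layer
# should be rotation equivariant, so rotating the input coordinates should
# rotate the output coordinates identically while leaving the features unchanged.
rot = torch.tensor([
    [0., -1., 0.],
    [1.,  0., 0.],
    [0.,  0., 1.]
])  # exact 90-degree rotation about the z-axis

layer = EGNN(dim = 64, edge_dim = 4)
feats = torch.randn(1, 16, 64)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats1, coors1 = layer(feats, coors, edges)
feats2, coors2 = layer(feats, coors @ rot, edges)

# loose tolerances to absorb float32 accumulation error through the MLPs
assert torch.allclose(feats1, feats2, atol = 1e-4)        # features are invariant
assert torch.allclose(coors1 @ rot, coors2, atol = 1e-4)  # coordinates rotate with the input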