Skip to content

Instantly share code, notes, and snippets.

@horoiwa
Last active February 18, 2024 14:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save horoiwa/2bd2a7e5c8c76ec1db106178ac0fcc3e to your computer and use it in GitHub Desktop.
Save horoiwa/2bd2a7e5c8c76ec1db106178ac0fcc3e to your computer and use it in GitHub Desktop.
import tensorflow as tf
import tensorflow.keras.layers as kl
class EquivariantGNNBlock(tf.keras.Model):
    """One equivariant message-passing layer over padded, batched graphs.

    Node features ``h`` are updated residually from attention-weighted edge
    messages. Node coordinates ``x`` are updated by a weighted sum of relative
    position vectors ``x_i - x_j``. The weights are computed only from
    ``h``, inter-node distances and ``edge_attr``.

    Args:
        hidden_dim: Width of the hidden MLP layers (default 256, the
            previously hard-coded value). The residual in ``update_h`` adds
            ``dense_h``'s output to ``h``, so ``h`` must have ``hidden_dim``
            channels.
    """

    def __init__(self, hidden_dim=256):
        super().__init__()
        init = 'truncated_normal'
        # Edge MLP: (h_i, h_j, d_ij, edge_attr) -> message m_ij.
        self.dense_e = tf.keras.Sequential([
            kl.Dense(hidden_dim, activation=tf.nn.silu, kernel_initializer=init),
            kl.Dense(hidden_dim, activation=tf.nn.silu, kernel_initializer=init),
        ])
        # One sigmoid gate per edge, used to weight m_ij.
        self.e_attention = kl.Dense(1, activation='sigmoid', kernel_initializer=init)
        # Node MLP; the last layer is linear because its output is added
        # residually to h.
        self.dense_h = tf.keras.Sequential([
            kl.Dense(hidden_dim, activation=tf.nn.silu, kernel_initializer=init),
            kl.Dense(hidden_dim, activation=None, kernel_initializer=init),
        ])
        # Coordinate MLP: one scalar weight per edge.
        self.dense_x = tf.keras.Sequential([
            kl.Dense(hidden_dim, activation=tf.nn.silu, kernel_initializer=init),
            kl.Dense(hidden_dim, activation=tf.nn.silu, kernel_initializer=init),
            kl.Dense(1, activation=None, use_bias=True, kernel_initializer=init),
        ])

    def call(self, x, h, edge_attr, edge_indices, node_mask, edge_mask):
        """Run one message-passing step.

        Shapes are assumptions inferred from the indexing; TODO confirm
        against the caller:
            x: (B, N, 3) node coordinates.
            h: (B, N, hidden_dim) node features.
            edge_attr: (B, N*N, E) per-edge features.
            edge_indices: (B, N*N, 2) integer (i, j) node indices per edge.
            node_mask: presumably (B, N, 1), 1 for real nodes, 0 for padding.
            edge_mask: presumably (B, N*N, 1), 1 for real edges, 0 for padding.

        Returns:
            Tuple ``(x_out, h_out)``: updated coordinates and features, both
            multiplied by ``node_mask``.
        """
        indices_i, indices_j = edge_indices[..., 0:1], edge_indices[..., 1:2]
        x_i = tf.gather_nd(x, indices_i, batch_dims=1)
        x_j = tf.gather_nd(x, indices_j, batch_dims=1)
        diff_ij = (x_i - x_j) * edge_mask
        # The epsilon keeps sqrt's gradient finite for zero-length vectors;
        # re-masking sets padded edges back to exactly zero.
        d_ij = tf.sqrt(tf.reduce_sum(diff_ij**2, axis=-1, keepdims=True) + 1e-8) * edge_mask
        h_i = tf.gather_nd(h, indices_i, batch_dims=1)
        h_j = tf.gather_nd(h, indices_j, batch_dims=1)
        feat = tf.concat([h_i, h_j, d_ij, edge_attr], axis=-1) * edge_mask
        h_out = self.update_h(h, feat, indices_i, edge_mask=edge_mask) * node_mask
        x_out = self.update_x(x, diff_ij, d_ij, feat, indices_i) * node_mask
        return x_out, h_out

    def update_h(self, h_in, feat, indices_i, edge_mask=None):
        """Residual node-feature update from attention-weighted edge messages.

        Args:
            h_in: Node features.
            feat: Per-edge input features.
            indices_i: Receiving-node index of each edge, trailing dim of 1.
            edge_mask: Optional per-edge mask. When given, messages on masked
                edges are zeroed before aggregation. Zeroing ``feat`` alone is
                not enough: ``dense_e`` and ``e_attention`` have biases, so a
                zero input still yields a non-zero message that would leak
                into the receiving node.

        Returns:
            ``h_in + dense_h(concat([h_in, summed messages]))``.
        """
        m_ij = self.dense_e(feat)
        e_ij = self.e_attention(m_ij)
        em_ij = e_ij * m_ij
        if edge_mask is not None:
            em_ij = em_ij * edge_mask
        em_agg = segmnt_sum_by_node(em_ij, indices_i)
        h_out = h_in + self.dense_h(tf.concat([h_in, em_agg], axis=-1))
        return h_out

    def update_x(self, x_in, diff_ij, d_ij, feat, indices_i):
        """Move each node along its relative position vectors.

        Each edge contributes ``diff_ij / (1 + d_ij) * dense_x(feat)``. The
        contributions are summed per receiving node and added to ``x_in``.
        Masked edges contribute nothing as long as ``diff_ij`` is already
        zeroed on them, which ``call`` does.
        """
        x = self.dense_x(feat)
        x = (diff_ij / (1.0 + d_ij)) * x  # (B, N*N, 3) * (B, N*N, 1) -> (B, N*N, 3)
        x_agg = segmnt_sum_by_node(x, indices_i)
        x_out = x_in + x_agg
        return x_out
def segmnt_sum_by_node(data, indices_i, num_nodes=None):
    """Sum per-edge values into their receiving node, separately per batch item.

    Args:
        data: (B, NN, D) per-edge values.
        indices_i: (B, NN, 1) integer node index (within its own graph) that
            each edge's value is added to. Expected to be < ``num_nodes``;
            larger values would spill into the next graph's segments.
        num_nodes: Nodes per (padded) graph, as a Python int. Defaults to
            ``settings.MAX_NUM_ATOMS``.

    Returns:
        (B, num_nodes, D) tensor. Entry ``[b, n]`` is the sum of
        ``data[b, k]`` over all ``k`` with ``indices_i[b, k, 0] == n``. Nodes
        that receive no edges are zero.
    """
    if num_nodes is None:
        # NOTE(review): `settings` is not imported in this file; confirm
        # where it is supposed to come from.
        num_nodes = settings.MAX_NUM_ATOMS
    # The batch and edge axes use dynamic shapes so this also works when the
    # batch size is unknown at trace time. The channel dim stays static when
    # known, so downstream Dense layers can still build.
    batch_size = tf.shape(data)[0]
    depth = data.shape[-1]
    if depth is None:
        depth = tf.shape(data)[-1]
    flat_data = tf.reshape(data, [-1, depth])  # (B, NN, D) -> (B*NN, D)
    # Offset each graph's node indices by b * num_nodes so every
    # (batch, node) pair gets its own segment id.
    offsets = tf.cast(tf.range(batch_size), indices_i.dtype)[:, None] * num_nodes
    segment_ids = tf.reshape(offsets + tf.squeeze(indices_i, axis=-1), [-1])
    summed = tf.math.unsorted_segment_sum(
        data=flat_data,
        segment_ids=segment_ids,
        num_segments=batch_size * num_nodes,
    )
    return tf.reshape(summed, [batch_size, num_nodes, depth])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment