Skip to content

Instantly share code, notes, and snippets.

@horoiwa
Created February 19, 2024 12:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save horoiwa/cf5d3857d7a68f18c5ba1a0171a1e592 to your computer and use it in GitHub Desktop.
Save horoiwa/cf5d3857d7a68f18c5ba1a0171a1e592 to your computer and use it in GitHub Desktop.
def compute_loss(self, x, h, edge_indices, node_masks, edge_masks):
    """Compute the element-wise noise-prediction loss for one batch.

    Samples one diffusion timestep per molecule and diffuses the centred,
    scaled inputs forward to that timestep. It then asks the network
    (``self``) to predict the injected Gaussian noise and returns the
    unreduced squared error, masked so that padding atoms contribute nothing.

    Args:
        x: xyz coordinates, shape == (B, N, 3).
        h: one-hot encoded atom types, shape == (B, N, len(settings.ATOM_MAP)).
        edge_indices: index numbers for every pair of two atoms,
            shape == (B, N*N, ...).
        node_masks: whether each atom is real (not a padded dummy atom),
            shape == (B, N, ...). NOTE(review): it is multiplied into float32
            (B, N, 1) tensors below, so presumably a float 0/1 tensor of
            shape (B, N, 1) — confirm against callers.
        edge_masks: whether both endpoints of each edge are real atoms,
            shape == (B, N*N, ...).

    Returns:
        ``0.5 * (eps - eps_pred) ** 2 * node_masks``, not reduced over any
        axis. Its shape follows ``eps`` (presumably
        (B, N, 3 + len(settings.ATOM_MAP))).
    """
    B, N = x.shape[0], x.shape[1]

    # Translate so the centre of mass sits at (0, 0, 0), then scale.
    # remove_mean receives node_masks, presumably so padding atoms are
    # excluded from the mean (defined elsewhere — not verified here).
    x_0 = remove_mean(x, node_masks) / self.scale_x
    h_0 = h / self.scale_h
    z_0 = tf.concat([x_0, h_0], axis=-1)  # (B, N, 3 + len(settings.ATOM_MAP))

    # One diffusion timestep per molecule. Integer maxval is exclusive in
    # tf.random.uniform, so the sampled range is 0 <= t <= T.
    timesteps = tf.random.uniform(
        shape=(B, 1),
        minval=0,
        maxval=self.T + 1,
        dtype=tf.int32,
    )

    # Normalised time t/T copied to every atom, zeroed at padding atoms.
    t = tf.reshape(
        tf.repeat(tf.cast(timesteps / self.T, tf.float32), repeats=N, axis=1),
        shape=(B, N, 1),
    ) * node_masks

    # Cumulative alpha product at each molecule's timestep, reshaped to
    # (B, 1, 1) so it broadcasts over atoms and features.
    alphas_cumprod_t = tf.reshape(
        tf.gather(self.alphas_cumprod, indices=timesteps),
        shape=(-1, 1, 1),
    )

    # Forward diffusion: z_t = sqrt(abar_t) * z_0 + sqrt(1 - abar_t) * eps.
    eps = sample_gaussian_noise(
        shape_x=x_0.shape, shape_h=h_0.shape, node_masks=node_masks
    )
    z_t = tf.sqrt(alphas_cumprod_t) * z_0 + tf.sqrt(1.0 - alphas_cumprod_t) * eps
    x_t, h_t = z_t[..., :3], z_t[..., 3:]

    # Noise prediction by the equivariant GCNN.
    eps_pred = self(x_t, h_t, t, edge_indices, node_masks, edge_masks)

    # Mask explicitly so padding atoms never contribute to the loss, whether
    # or not the network zeroes its own output there. For real atoms
    # (mask == 1) this leaves the value unchanged.
    loss = 0.5 * (eps - eps_pred) ** 2 * node_masks
    return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment