Created February 19, 2024 12:55
Save horoiwa/cf5d3857d7a68f18c5ba1a0171a1e592 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def compute_loss(self, x, h, edge_indices, node_masks, edge_masks):
    """Compute the element-wise noise-prediction (epsilon) loss for a batch.

    For each molecule a diffusion timestep 0 <= t <= T is sampled. The
    centred and scaled input z_0 = [x_0, h_0] is noised to
    z_t = sqrt(alphas_cumprod[t]) * z_0 + sqrt(1 - alphas_cumprod[t]) * eps,
    and the model is asked to predict eps from z_t.

    Args:
        x: xyz coordinates, shape == (B, N, 3).
        h: one-hot encoded atom types, shape == (B, N, len(settings.ATOM_MAP)).
        edge_indices: index numbers of every two-atom combination,
            shape == (B, N*N, ...).
        node_masks: whether each atom is real (not a padding dummy atom),
            shape == (B, N, ...). Used as a numeric 0/1 multiplier below;
            presumably (B, N, 1) — TODO confirm against the data pipeline.
        edge_masks: whether both endpoints of each edge are real atoms,
            shape == (B, N*N, ...).

    Returns:
        Element-wise loss 0.5 * (eps - eps_pred) ** 2, same shape as z_0
        (B, N, 3 + n_atom_types), zeroed at padded atoms. Any reduction
        (sum/mean) is left to the caller.
    """
    # Translate so the centroid is at (0, 0, 0), then scale.
    x_0 = remove_mean(x, node_masks) / self.scale_x
    h_0 = h / self.scale_h
    z_0 = tf.concat([x_0, h_0], axis=-1)  # (B, N, 3 + n_atom_types)

    # Sample one diffusion timestep per molecule: 0 <= t <= T.
    # The integer maxval of tf.random.uniform is exclusive, hence T + 1.
    # The dynamic batch size is used so this does not rely on a known
    # static batch dimension.
    batch_size = tf.shape(x_0)[0]
    timesteps = tf.random.uniform(
        shape=[batch_size, 1],
        minval=0,
        maxval=self.T + 1,
        dtype=tf.int32,
    )  # (B, 1)

    # Normalised time t / T as a model input. (B, 1, 1) broadcasts against
    # node_masks to give every real atom its molecule's time and 0 at padding.
    t = tf.reshape(
        tf.cast(timesteps / self.T, tf.float32), shape=(-1, 1, 1)
    ) * node_masks

    alphas_cumprod_t = tf.reshape(
        tf.gather(self.alphas_cumprod, indices=timesteps),
        shape=(-1, 1, 1),
    )  # (B, 1, 1)

    # Forward diffusion process: q(z_t | z_0).
    eps = sample_gaussian_noise(
        shape_x=x_0.shape, shape_h=h_0.shape, node_masks=node_masks
    )
    z_t = tf.sqrt(alphas_cumprod_t) * z_0 + tf.sqrt(1.0 - alphas_cumprod_t) * eps
    x_t, h_t = z_t[..., :3], z_t[..., 3:]

    # Noise prediction by the equivariant GNN.
    eps_pred = self(x_t, h_t, t, edge_indices, node_masks, edge_masks)

    loss = 0.5 * (eps - eps_pred) ** 2
    # Padded dummy atoms must not contribute to the loss. Masking here is a
    # no-op if the model output is already zero at padded atoms, and
    # prevents leakage if it is not.
    return loss * node_masks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.