Skip to content

Instantly share code, notes, and snippets.

@horoiwa
Created July 11, 2023 08:05
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save horoiwa/db3d744e4845db0890a63f26a99e43f5 to your computer and use it in GitHub Desktop.
class DiffusionPolicy(tf.keras.Model):
    """State-conditioned diffusion policy trained by epsilon (noise) prediction.

    The network maps ``[noisy action x_t, time embedding of t, state]`` through
    three Dense(256, mish) layers to a vector of size ``action_space``.
    That vector is regressed onto the Gaussian noise that was mixed into the
    action (see ``compute_bc_loss``). Actions are generated by starting from
    standard-normal noise and applying ``inv_diffusion`` for
    t = n_timesteps-1, ..., 0 (see ``sample_actions``).

    Args:
        action_space: Size of the action vector (output width of the network).
        n_timesteps: Number of diffusion steps T. Defaults to 5, the value the
            original implementation hard-coded.
    """

    def __init__(self, action_space: int, n_timesteps: int = 5):
        super().__init__()
        self.n_timesteps = n_timesteps
        self.action_space = action_space

        self.time_embedding = SinusoidalPositionalEmbedding(L=self.n_timesteps, D=12)
        self.dense1 = kl.Dense(256, activation=mish)
        self.dense2 = kl.Dense(256, activation=mish)
        self.dense3 = kl.Dense(256, activation=mish)
        self.out = kl.Dense(self.action_space, activation=None)

        # NOTE(review): presumably 1-D tensors of length T with alphas == 1 - betas.
        # The reverse-step mean in inv_diffusion relies on the latter; confirm
        # against get_noise_schedule.
        self.alphas, self.betas = get_noise_schedule(T=self.n_timesteps)
        # abar_t = prod_{s<=t} alpha_s
        self.alphas_cumprod = tf.math.cumprod(self.alphas)
        # abar_{t-1}, defining abar_{-1} := 1 so that index t = 0 is valid.
        self.alphas_cumprod_prev = tf.concat([[1.], self.alphas_cumprod[:-1]], axis=0)
        # Posterior variance beta_t * (1 - abar_{t-1}) / (1 - abar_t).
        # Because abar_{-1} == 1, this is 0 at t = 0 (assuming betas[0] > 0), so the
        # final reverse step adds no noise.
        self.variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)

    @staticmethod
    def _gather_coeff(coeffs, t):
        """Return ``coeffs[t]`` reshaped to (B, 1) so it broadcasts over action dims."""
        return tf.reshape(tf.gather(coeffs, indices=t), (-1, 1))

    def call(self, x, timesteps, states):
        """Predict the noise contained in ``x`` at diffusion step ``timesteps``.

        Args:
            x: Noisy actions x_t, presumably (B, action_space).
            timesteps: int32 diffusion steps in [0, n_timesteps), shape (B, 1).
            states: Conditioning states, presumably (B, state_dim); TODO confirm.

        Returns:
            Predicted noise ``eps``; last dimension is ``action_space``.
        """
        # Assumes the embedding yields a rank-2 (B, D) tensor so it can be
        # concatenated on axis 1; TODO confirm against SinusoidalPositionalEmbedding.
        t = self.time_embedding(timesteps)
        x = tf.concat([x, t, states], axis=1)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        return self.out(x)

    @tf.function
    def compute_bc_loss(self, actions, states):
        """Return the mean-squared noise-prediction loss for a batch of actions.

        A diffusion step is sampled uniformly per example. The actions are
        noised in closed form, and the network's noise prediction is compared
        with the true noise.

        Args:
            actions: Target actions x_0, presumably (B, action_space); the sampler
                clips to [-1, 1], so they are presumably in that range; TODO confirm.
            states: Conditioning states, one row per action.

        Returns:
            Scalar loss ``mean((eps - eps_pred) ** 2)``.
        """
        x_0 = actions
        # Dynamic batch size: static ``.shape[0]`` is None when traced with an
        # unknown batch dimension.
        batch_size = tf.shape(x_0)[0]
        # One int32 step per example, uniform in [0, n_timesteps), shape (B, 1).
        # (The original had a stray trailing comma here, which wrapped this
        # tensor in a tuple and fed the network a different input shape than
        # sampling does.)
        timesteps = tf.random.uniform(
            shape=(batch_size, 1), minval=0, maxval=self.n_timesteps, dtype=tf.int32
        )
        alphas_cumprod_t = self._gather_coeff(self.alphas_cumprod, timesteps)  # (B, 1)
        eps = tf.random.normal(shape=tf.shape(x_0), mean=0., stddev=1.)
        # Closed-form forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
        x_t = tf.sqrt(alphas_cumprod_t) * x_0 + tf.sqrt(1. - alphas_cumprod_t) * eps
        eps_pred = self(x_t, timesteps, states)
        return tf.reduce_mean(tf.square(eps - eps_pred))

    @tf.function
    def sample_actions(self, states):
        """Generate actions for ``states`` by running the reverse diffusion chain.

        Args:
            states: Conditioning states; the first dimension is the batch size B.

        Returns:
            Tensor of shape (B, action_space), clipped to [-1, 1].
        """
        batch_size = tf.shape(states)[0]
        x_t = tf.random.normal(shape=(batch_size, self.action_space), mean=0., stddev=1.)
        # n_timesteps is a Python int, so this loop is unrolled at trace time.
        for step in reversed(range(self.n_timesteps)):
            t = step * tf.ones(shape=(batch_size, 1), dtype=tf.int32)  # (B, 1)
            x_t = self.inv_diffusion(x_t, t, states)
        return tf.clip_by_value(x_t, -1.0, 1.0)

    def inv_diffusion(self, x_t, t, states):
        """Apply one reverse step, sampling x_{t-1} from x_t.

        Args:
            x_t: Current noisy actions.
            t: int32 diffusion steps, shape (B, 1).
            states: Conditioning states.

        Returns:
            ``mu + sigma * noise``, where sigma is the square root of the
            posterior variance at step t (zero at t = 0; see ``__init__``).
        """
        beta_t = self._gather_coeff(self.betas, t)  # (B, 1)
        alphas_cumprod_t = self._gather_coeff(self.alphas_cumprod, t)  # (B, 1)
        eps_t = self(x_t, t, states)
        # mu = 1/sqrt(1 - beta_t) * (x_t - beta_t / sqrt(1 - abar_t) * eps_t).
        # This is the DDPM posterior mean provided alphas == 1 - betas (TODO confirm).
        mu = (1.0 / tf.sqrt(1.0 - beta_t)) * (x_t - (beta_t / tf.sqrt(1.0 - alphas_cumprod_t)) * eps_t)
        sigma = tf.sqrt(self._gather_coeff(self.variance, t))  # (B, 1)
        noise = tf.random.normal(shape=tf.shape(x_t), mean=0., stddev=1.)
        return mu + sigma * noise
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment