Gist by @horoiwa · Created November 28, 2022 11:52
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as kl
import tensorflow_probability as tfp
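
# ---------------------------------------------------------------------------
# The code below references StateEmbedding, PositionalEmbedding and
# DecoderBlock without showing them in this excerpt. The minimal sketches
# here are reconstructions that match the shapes the DecisionTransformer
# code expects (following the Decision Transformer paper's Atari setup),
# not the author's exact implementation.
# ---------------------------------------------------------------------------

class StateEmbedding(tf.keras.layers.Layer):
    """Per-frame CNN encoder: (B, L, 84, 84, 4) -> (B, L, embed_dim).

    A DQN-style conv stack is assumed here; the original layer may differ.
    """

    def __init__(self, embed_dim, input_shape):
        super().__init__()
        self.in_shape = input_shape
        self.conv1 = kl.Conv2D(32, 8, strides=4, activation="relu")
        self.conv2 = kl.Conv2D(64, 4, strides=2, activation="relu")
        self.conv3 = kl.Conv2D(64, 3, strides=1, activation="relu")
        self.flatten = kl.Flatten()
        self.dense = kl.Dense(embed_dim, activation=None)

    def call(self, states):
        B, L = tf.shape(states)[0], tf.shape(states)[1]
        x = tf.reshape(states, (-1, *self.in_shape))  # fold (B, L) into batch
        x = self.conv3(self.conv2(self.conv1(x)))
        x = self.dense(self.flatten(x))
        return tf.reshape(x, (B, L, -1))  # (B, L, embed_dim)


class PositionalEmbedding(tf.keras.layers.Layer):
    """Learned global (per-timestep) plus local (per-token) position embedding.

    Shapes are inferred from the call sites: given timesteps (B, 1, 1) and the
    current context length L, returns (B, 3L, embed_dim).
    """

    def __init__(self, max_timestep, context_length, embed_dim):
        super().__init__()
        self.global_pos_emb = self.add_weight(
            name="global_pos_emb", shape=(max_timestep + 1, embed_dim),
            initializer="zeros")
        self.local_pos_emb = self.add_weight(
            name="local_pos_emb", shape=(3 * context_length, embed_dim),
            initializer="zeros")

    def call(self, timesteps, L):
        # (B, 1, 1) -> (B, 1) -> one global embedding per sequence: (B, 1, E)
        global_emb = tf.gather(
            self.global_pos_emb, tf.squeeze(timesteps, axis=-1))
        return global_emb + self.local_pos_emb[None, :3 * L, :]  # (B, 3L, E)


class DecoderBlock(tf.keras.layers.Layer):
    """Pre-LN GPT block: causal self-attention followed by a pointwise MLP.

    A standard minGPT-style block is assumed; dropout placement and MLP width
    may differ from the original gist.
    """

    def __init__(self, n_heads, embed_dim, context_length):
        super().__init__()
        self.ln1 = kl.LayerNormalization()
        self.ln2 = kl.LayerNormalization()
        self.attn = kl.MultiHeadAttention(
            num_heads=n_heads, key_dim=embed_dim // n_heads, dropout=0.1)
        self.mlp = tf.keras.Sequential([
            kl.Dense(4 * embed_dim, activation="gelu"),
            kl.Dense(embed_dim),
            kl.Dropout(0.1),
        ])
        # Lower-triangular (causal) mask over the 3 * context_length tokens.
        T = 3 * context_length
        self.causal_mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)

    def call(self, x, training=False):
        T = tf.shape(x)[1]
        mask = self.causal_mask[tf.newaxis, :T, :T]  # broadcast over batch
        h = self.ln1(x)
        x = x + self.attn(h, h, attention_mask=mask, training=training)
        x = x + self.mlp(self.ln2(x), training=training)
        return x
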
class DecisionTransformer(tf.keras.Model):

    def __init__(self, action_space, max_timestep, context_length=30,
                 n_blocks=6, n_heads=8, embed_dim=128):
        super().__init__()
        self.state_shape = (84, 84, 4)
        self.action_space = action_space
        self.context_length = context_length
        self.embed_dim = embed_dim

        self.rtgs_embedding = kl.Dense(
            self.embed_dim, activation=None,
            kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))
        self.state_embedding = StateEmbedding(self.embed_dim, input_shape=self.state_shape)
        self.action_embedding = kl.Embedding(
            self.action_space, self.embed_dim,
            embeddings_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))
        self.pos_embedding = PositionalEmbedding(
            max_timestep=max_timestep,
            context_length=context_length,
            embed_dim=embed_dim)

        self.dropout = kl.Dropout(0.1)
        self.blocks = [DecoderBlock(n_heads, embed_dim, context_length)
                       for _ in range(n_blocks)]
        self.layer_norm = kl.LayerNormalization()
        self.head = kl.Dense(
            self.action_space, use_bias=False, activation=None,
            kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))

    @tf.function
    def call(self, rtgs, states, actions, timesteps, training=False):
        """
        Args:
            rtgs: dtype=tf.float32, shape=(B, L, 1)
            states: dtype=tf.float32, shape=(B, L, 84, 84, 4)
            actions: dtype=tf.uint8, shape=(B, L, 1)
            timesteps: dtype=tf.int32, shape=(B, 1, 1)
        """
        B, L = rtgs.shape[0], rtgs.shape[1]

        rtgs_embed = tf.math.tanh(self.rtgs_embedding(rtgs))       # (B, L, embed_dim)
        states_embed = tf.math.tanh(self.state_embedding(states))  # (B, L, embed_dim)
        action_embed = tf.math.tanh(
            self.action_embedding(tf.squeeze(actions, axis=-1))
        )  # (B, L, 1) -> (B, L) -> (B, L, embed_dim)

        pos_embed = self.pos_embedding(timesteps, L)  # (B, 3L, embed_dim)

        # Interleave the modalities as (rtg_1, s_1, a_1, rtg_2, s_2, a_2, ...).
        tokens = tf.stack(
            [rtgs_embed, states_embed, action_embed], axis=1)  # (B, 3, L, embed_dim)
        tokens = tf.reshape(
            tf.transpose(tokens, (0, 2, 1, 3)),
            (B, 3 * L, self.embed_dim))  # (B, 3L, embed_dim)

        x = self.dropout(tokens + pos_embed, training=training)
        for block in self.blocks:
            x = block(x, training=training)
        x = self.layer_norm(x)
        logits = self.head(x)  # (B, 3L, action_space)

        # Use only the predictions made from the state tokens, which sit at
        # positions 1, 4, 7, ... in the interleaved sequence.
        logits = logits[:, 1::3, :]  # (B, L, action_space)

        return logits

    def sample_action(self, rtgs: list, states: list, actions: list, timestep: int):
        """Sample an action for the latest state; returns (action, probs)."""
        assert len(rtgs) == len(states) == len(actions) + 1

        # Keep only the most recent `context_length` steps so the reshapes
        # below match the tensor sizes (the reshape would otherwise fail on
        # histories longer than the context window).
        L = min(len(rtgs), self.context_length)
        rtgs, states = rtgs[-L:], states[-L:]

        rtgs = tf.reshape(
            tf.convert_to_tensor(rtgs, dtype=tf.float32), shape=[1, L, 1])
        states = tf.expand_dims(tf.stack(states, axis=0), 0)

        # Append a dummy action at the end for implementation convenience:
        # the current action is what we are about to predict, and only the
        # state-token logits are read out anyway.
        actions = (actions + [0])[-L:]
        actions = tf.reshape(
            tf.convert_to_tensor(actions, dtype=tf.uint8), [1, L, 1])

        timestep = tf.reshape(
            tf.convert_to_tensor([timestep], dtype=tf.int32), [1, 1, 1])

        logits_all = self(rtgs, states, actions, timestep)  # (1, L, A)
        logits = logits_all[0, -1, :]

        probs = tf.nn.softmax(logits)
        dist = tfp.distributions.Categorical(probs=probs)
        sampled_action = dist.sample().numpy()

        return sampled_action, probs
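

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist): action_space=4 and
# max_timestep=4000 are placeholder hyperparameters for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DecisionTransformer(action_space=4, max_timestep=4000)

    # One step of return-conditioned sampling: a target return-to-go, the
    # current stacked-frame observation, and no past actions yet.
    rtgs = [90.0]
    states = [np.zeros((84, 84, 4), dtype=np.float32)]
    action, probs = model.sample_action(rtgs, states, actions=[], timestep=0)
    print(action, probs.numpy())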