import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as kl
import tensorflow_probability as tfp
class DecisionTransformer(tf.keras.Model):
    """Decision Transformer for discrete control from stacked 84x84x4 frames."""

    def __init__(self, action_space, max_timestep, context_length=30,
                 n_blocks=6, n_heads=8, embed_dim=128):
        super().__init__()

        self.state_shape = (84, 84, 4)
        self.action_space = action_space
        self.context_length = context_length
        self.embed_dim = embed_dim

        self.rtgs_embedding = kl.Dense(
            self.embed_dim, activation=None,
            kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))
        self.state_embedding = StateEmbedding(self.embed_dim, input_shape=self.state_shape)
        self.action_embedding = kl.Embedding(
            self.action_space, self.embed_dim,
            embeddings_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))
        self.pos_embedding = PositionalEmbedding(
            max_timestep=max_timestep,
            context_length=context_length,
            embed_dim=embed_dim)

        self.dropout = kl.Dropout(0.1)
        self.blocks = [DecoderBlock(n_heads, embed_dim, context_length) for _ in range(n_blocks)]
        self.layer_norm = kl.LayerNormalization()
        self.head = kl.Dense(
            self.action_space, use_bias=False, activation=None,
            kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))
    @tf.function
    def call(self, rtgs, states, actions, timesteps, training=False):
        """
        Args:
            rtgs: dtype=tf.float32, shape=(B, L, 1)
            states: dtype=tf.float32, shape=(B, L, 84, 84, 4)
            actions: dtype=tf.uint8, shape=(B, L, 1)
            timesteps: dtype=tf.int32, shape=(B, 1, 1)
        """
        B, L = rtgs.shape[0], rtgs.shape[1]

        rtgs_embed = tf.math.tanh(self.rtgs_embedding(rtgs))        # (B, L, embed_dim)
        states_embed = tf.math.tanh(self.state_embedding(states))   # (B, L, embed_dim)
        action_embed = tf.math.tanh(
            self.action_embedding(tf.squeeze(actions, axis=-1))
        )  # (B, L, 1) -> (B, L) -> (B, L, embed_dim)

        pos_embed = self.pos_embedding(timesteps, L)  # (B, 3L, embed_dim)

        # Interleave the tokens per timestep as (rtg_t, state_t, action_t).
        tokens = tf.stack(
            [rtgs_embed, states_embed, action_embed], axis=1)  # (B, 3, L, embed_dim)
        tokens = tf.reshape(
            tf.transpose(tokens, (0, 2, 1, 3)),
            (B, 3 * L, self.embed_dim))  # (B, 3L, embed_dim)

        x = self.dropout(tokens + pos_embed, training=training)
        for block in self.blocks:
            x = block(x, training=training)
        x = self.layer_norm(x)
        logits = self.head(x)  # (B, 3L, action_space)

        # Use only the predictions made from state tokens (positions 1, 4, 7, ...).
        logits = logits[:, 1::3, :]  # (B, L, action_space)

        return logits
    def sample_action(self, rtgs: list, states: list, actions: list, timestep: int):
        assert len(rtgs) == len(states) == len(actions) + 1

        # Append a dummy action at the end for implementation convenience;
        # the causal mask keeps the latest state token from attending to it.
        actions = actions + [0]

        # Keep only the most recent L steps so the tensors fit the context window.
        L = min(len(rtgs), self.context_length)
        rtgs = tf.reshape(
            tf.convert_to_tensor(rtgs[-L:], dtype=tf.float32), shape=[1, L, 1])
        states = tf.expand_dims(tf.stack(states[-L:], axis=0), 0)
        actions = tf.reshape(
            tf.convert_to_tensor(actions[-L:], dtype=tf.uint8), [1, L, 1])
        timestep = tf.reshape(
            tf.convert_to_tensor([timestep], dtype=tf.int32), [1, 1, 1])

        logits_all = self(rtgs, states, actions, timestep)  # (1, L, A)
        logits = logits_all[0, -1, :]  # prediction conditioned on the latest state

        probs = tf.nn.softmax(logits)
        dist = tfp.distributions.Categorical(probs=probs)
        sampled_action = dist.sample().numpy()

        return sampled_action, probs
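

# ----------------------------------------------------------------------------
# The gist references three helper modules (StateEmbedding, PositionalEmbedding,
# DecoderBlock) that are not shown in this excerpt. What follows is a minimal
# sketch of plausible implementations, modeled on the original Decision
# Transformer Atari code: a Nature-DQN-style conv encoder, global + local
# learned position embeddings, and GPT-style pre-LN decoder blocks. Treat the
# layer sizes and details as assumptions, not as the gist author's actual code.
# ----------------------------------------------------------------------------

class StateEmbedding(tf.keras.Model):

    def __init__(self, embed_dim, input_shape):
        super().__init__()
        self.frame_shape = input_shape  # assumed (84, 84, 4)
        self.conv1 = kl.Conv2D(32, 8, strides=4, activation="relu")
        self.conv2 = kl.Conv2D(64, 4, strides=2, activation="relu")
        self.conv3 = kl.Conv2D(64, 3, strides=1, activation="relu")
        self.flatten = kl.Flatten()
        self.dense = kl.Dense(embed_dim, activation=None)

    def call(self, states):
        # (B, L, 84, 84, 4): fold time into batch so the CNN sees 4D input.
        B, L = tf.shape(states)[0], tf.shape(states)[1]
        x = tf.reshape(states, (-1, *self.frame_shape))
        x = self.conv3(self.conv2(self.conv1(x)))
        x = self.dense(self.flatten(x))
        return tf.reshape(x, (B, L, -1))  # (B, L, embed_dim)


class PositionalEmbedding(tf.keras.layers.Layer):

    def __init__(self, max_timestep, context_length, embed_dim):
        super().__init__()
        # Global embedding indexed by the episode timestep of the window start,
        # plus a local embedding for each of the 3 * context_length token slots.
        self.global_pos_emb = self.add_weight(
            name="global_pos_emb", shape=(max_timestep + 1, embed_dim),
            initializer="zeros")
        self.local_pos_emb = self.add_weight(
            name="local_pos_emb", shape=(3 * context_length, embed_dim),
            initializer="zeros")

    def call(self, timesteps, L):
        # timesteps: (B, 1, 1) -> (B, 1) -> gather -> (B, 1, embed_dim)
        global_emb = tf.gather(self.global_pos_emb, tf.squeeze(timesteps, axis=-1))
        local_emb = self.local_pos_emb[tf.newaxis, :3 * L, :]  # (1, 3L, embed_dim)
        return global_emb + local_emb  # broadcasts to (B, 3L, embed_dim)


class DecoderBlock(tf.keras.layers.Layer):

    def __init__(self, n_heads, embed_dim, context_length):
        super().__init__()
        self.ln1 = kl.LayerNormalization()
        self.ln2 = kl.LayerNormalization()
        self.attn = kl.MultiHeadAttention(
            num_heads=n_heads, key_dim=embed_dim // n_heads, dropout=0.1)
        self.mlp = tf.keras.Sequential([
            kl.Dense(4 * embed_dim, activation="gelu"),
            kl.Dense(embed_dim),
            kl.Dropout(0.1),
        ])
        # Lower-triangular (causal) mask over all 3 * context_length positions.
        T = 3 * context_length
        self.causal_mask = tf.cast(
            tf.linalg.band_part(tf.ones((T, T)), -1, 0), tf.bool)

    def call(self, x, training=False):
        T = tf.shape(x)[1]
        h = self.ln1(x)
        x = x + self.attn(h, h, attention_mask=self.causal_mask[tf.newaxis, :T, :T],
                          training=training)
        x = x + self.mlp(self.ln2(x), training=training)
        return x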
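

# ----------------------------------------------------------------------------
# Hypothetical smoke test (not part of the gist): builds the model for a
# 4-action environment and samples one action from dummy Atari-sized frames.
# The action count, max_timestep, and return-to-go values are placeholders.
# ----------------------------------------------------------------------------

if __name__ == "__main__":
    model = DecisionTransformer(action_space=4, max_timestep=1000)

    rtgs = [90.0, 89.0]                                  # returns-to-go so far
    states = [np.zeros((84, 84, 4), dtype=np.float32),   # stacked frames
              np.zeros((84, 84, 4), dtype=np.float32)]
    actions = [1]                                        # actions taken so far

    action, probs = model.sample_action(rtgs, states, actions, timestep=0)
    print("sampled action:", action, "probs:", probs.numpy())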