Q-learning for the iterated Prisoner's Dilemma with Transformers, implemented in JAX
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
import optax
from functools import partial
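
# Assumed constants (not defined anywhere in the original gist): compute_loss below
# divides the reward by MAX_PAYOFF and the training loop passes GAMMA. MAX_PAYOFF = 5
# matches the largest entry of the payoff matrix defined further down, and GAMMA = 0.99
# matches the default discount factor used in compute_loss.
MAX_PAYOFF = 5
GAMMA = 0.99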
class TransformerBlock(nn.Module):
    embed_dim: int
    num_heads: int
    mlp_dim: int

    def setup(self):
        self.self_attention = nn.SelfAttention(num_heads=self.num_heads, qkv_features=self.embed_dim)
        self.layer_norm1 = nn.LayerNorm()
        self.mlp = nn.Sequential([
            nn.Dense(self.mlp_dim),
            nn.relu,
            nn.Dense(self.embed_dim)
        ])
        self.layer_norm2 = nn.LayerNorm()

    def __call__(self, x, train: bool):
        # Self-attention block
        attn_output = self.self_attention(x)
        x = x + attn_output
        x = self.layer_norm1(x)
        # Feed-forward block
        mlp_output = self.mlp(x)
        x = x + mlp_output
        x = self.layer_norm2(x)
        return x
class TransformerModel(nn.Module):
    num_layers: int
    embed_dim: int
    num_heads: int
    mlp_dim: int
    num_actions: int
    block_size: int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.num_actions * 2 + 1, features=self.embed_dim)  # Joint action pairs + pad token
        self.position_embedding = nn.Embed(num_embeddings=self.block_size, features=self.embed_dim)
        self.transformer_blocks = [TransformerBlock(self.embed_dim, self.num_heads, self.mlp_dim) for _ in range(self.num_layers)]
        self.final_layer = nn.Dense(self.num_actions)

    def __call__(self, x, train: bool = True):
        if x.ndim == 1:
            x = x.reshape(1, -1)
        seq_len = x.shape[1]
        x = self.embedding(x)
        position_indices = jnp.arange(seq_len)
        position_encodings = self.position_embedding(position_indices)
        x += position_encodings
        for block in self.transformer_blocks:
            x = block(x, train)
        # Only the output corresponding to the next game prediction
        return self.final_layer(x[:, -1])
def compute_loss(params, model, state, action, reward, next_state, next_action, gamma=0.99):
    q_values = model.apply(params, state, train=False)[0]
    next_q_values = model.apply(params, next_state, train=False)[0]
    # q_values = jnp.clip(q_values, 0, 1)
    # next_q_values = jnp.clip(next_q_values, 0, 1)
    target = reward / MAX_PAYOFF + gamma * jnp.max(next_q_values)
    loss = jnp.mean((q_values[action] - target) ** 2)
    return loss
def compute_loss_2(params, model, state, action, reward, next_state, next_action, gamma=0.99):
    q_values = model.apply(params, state, train=False)[0]
    next_q_values = model.apply(params, next_state, train=False)[0]
    target = reward / MAX_PAYOFF + gamma * jnp.max(next_q_values)
    loss = jnp.mean((q_values[action] - target) ** 2)
    return loss, q_values, next_q_values, target
@partial(jax.jit, static_argnums=(0,))
def _select_action(model, params, state):
    q_values = model.apply(params, state, train=False)[0]
    return jnp.argmax(q_values).astype(int)
@partial(jax.jit, static_argnums=(1,))
def _accumulate_gradients(params, model, opt_state, state, action, reward, next_state, gamma=0.99):
    state = jnp.array(state, dtype=jnp.int32).reshape(1, -1)
    next_state = jnp.array(next_state, dtype=jnp.int32).reshape(1, -1)
    # Pass None for the unused next_action argument so gamma is not silently bound to it.
    loss, grads = jax.value_and_grad(compute_loss)(params, model, state, action, reward, next_state, None, gamma)
    return loss, grads
@partial(jax.jit, static_argnums=(3,))
def _update_params(params, grads, opt_state, optimizer):
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state
class PrisonersDilemmaEnv:
    def __init__(self):
        self.num_actions = 2
        # Payoffs for (player1, player2); action 0 = cooperate, 1 = defect
        self.payoff_matrix = {
            (0, 0): (3, 3),
            (0, 1): (0, 5),
            (1, 0): (5, 0),
            (1, 1): (1, 1)
        }
        self.reset()

    def reset(self):
        return jnp.zeros((1,))

    def step(self, actions):  # TODO: multiple actions
        rewards = self.payoff_matrix[actions[0], actions[1]]
        return rewards
class QLearningAgent:
    def __init__(self, model, params, learning_rate=0.001, block_size=128, memory_size=128):
        self.model = model
        self.params = params
        self.optimizer = optax.adam(learning_rate)
        self.opt_state = self.optimizer.init(params)
        self.block_size = block_size
        self.memory_size = memory_size
        self.history = []
        self.gradients = []
        self.losses = []

    def probs(self):
        state = self.get_state()
        probs = jax.nn.softmax(self.model.apply(self.params, state, train=False)[0])
        return probs

    def select_action(self, epsilon):
        state = self.get_state()
        if np.random.rand() < epsilon:
            result = np.random.randint(2), state
            print('Choosing random!', result[0])
            return result
        else:
            action = _select_action(self.model, self.params, state)
            return action, state

    def calc_rewards(self, other_action, env):
        q_values = self.model.apply(self.params, self.get_state(), train=False)[0]
        action = jnp.argmax(q_values).astype(int)
        rewards = env.step((action, other_action))
        return rewards

    def get_state(self, next_action=None):
        history = self.history[-self.memory_size:]
        if next_action is not None:
            history += [next_action]
        state = jnp.array(history, dtype=jnp.int32)
        padding_length = self.memory_size - len(state)
        if padding_length > 0:
            state = jnp.pad(state, (0, padding_length), constant_values=4)  # Use 4 as the pad token.
        return state.reshape(1, -1)

    def get_action_pair(self, action1, action2):
        return action1 * 2 + action2

    def add_to_history(self, action1, action2):
        action_pair = action1 * 2 + action2
        self.history.append(action_pair)

    def accumulate_gradients(self, state, action, reward, next_state, gamma=0.99):
        loss, grads = _accumulate_gradients(self.params, self.model, self.opt_state, state, action, reward, next_state, gamma)
        self.gradients.append(grads)
        self.losses.append(loss)

    def update(self):
        mean_grads = jax.tree_util.tree_map(lambda *x: jnp.mean(jnp.stack(x), axis=0), *self.gradients)
        self.params, self.opt_state = _update_params(self.params, mean_grads, self.opt_state, self.optimizer)
        self.gradients = []
env = PrisonersDilemmaEnv()
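
# Assumed setup (the original gist uses p1 and p2 without constructing them): a minimal
# sketch that builds one TransformerModel, initialises two independent parameter sets,
# and wraps each in a QLearningAgent. The hyperparameter values here are illustrative
# and not taken from the original.
BLOCK_SIZE = 128
model = TransformerModel(num_layers=2, embed_dim=32, num_heads=2, mlp_dim=64,
                         num_actions=env.num_actions, block_size=BLOCK_SIZE)
rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
dummy_input = jnp.zeros((1, BLOCK_SIZE), dtype=jnp.int32)
p1 = QLearningAgent(model, model.init(rng1, dummy_input, train=False),
                    block_size=BLOCK_SIZE, memory_size=BLOCK_SIZE)
p2 = QLearningAgent(model, model.init(rng2, dummy_input, train=False),
                    block_size=BLOCK_SIZE, memory_size=BLOCK_SIZE)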
# Training loop
num_episodes = 500
update_every = 10
epsilon = 0.25
for episode in range(num_episodes):
    action1, state1 = p1.select_action(epsilon)
    action2, state2 = p2.select_action(epsilon)
    action1 = int(action1)
    action2 = int(action2)
    p1.add_to_history(action1, action2)
    p2.add_to_history(action2, action1)
    reward1, reward2 = env.step((action1, action2))
    next_state1 = p1.get_state()
    next_state2 = p2.get_state()
    p1.accumulate_gradients(state1, action1, reward1, next_state1, gamma=GAMMA)
    p2.accumulate_gradients(state2, action2, reward2, next_state2, gamma=GAMMA)
    if episode % 1 == 0:  # Parameters are updated every episode; update_every only controls logging.
        p1.update()
        p2.update()
    if (episode + 1) % update_every == 0:
        print(f"Episode {episode + 1}: Models updated")
    epsilon = max(0.001, epsilon * 0.99)
print("Training completed.")